Android 中集成 TensorFlow Lite图片识别

简介: Android 中集成 TensorFlow Lite图片识别

在上图通过手机的相机拍摄到的物体识别出具体的名称,这个需要通过TensorFlow 训练的模型引用到项目中;以下就是详细地集成 TensorFlow步骤,请按照以下步骤进行操作:

  1. 在项目的根目录下的 build.gradle 文件中添加 TensorFlow 的 Maven 仓库。在 repositories 部分添加以下行:


allprojects {
    repositories {
        // 其他仓库...
        maven {
            url 'https://google.bintray.com/tensorflow'
        }
    }
}
  1. 在应用的 build.gradle 文件中添加 TensorFlow Lite 的依赖。在 dependencies 部分添加以下行:
implementation 'org.tensorflow:tensorflow-lite:2.5.0'
  1. 将 TensorFlow Lite 模型文件添加到你的 Android 项目中。将模型文件(.tflite)复制到 app/src/main/assets 目录下。如果 assets 目录不存在,可以手动创建。
  2. 创建一个 TFLiteObjectDetectionAPIModel类,用于加载和运行 TensorFlow Lite 模型。以下是一个示例代码:


public class TFLiteObjectDetectionAPIModel implements Classifier {
  private static final Logger LOGGER = new Logger();
  // Only return this many results.
  private static final int NUM_DETECTIONS = 10;
  // Float model
  private static final float IMAGE_MEAN = 128.0f;
  private static final float IMAGE_STD = 128.0f;
  // Number of threads in the java app
  private static final int NUM_THREADS = 4;
  private boolean isModelQuantized;
  // Config values.
  private int inputSize;
  // Pre-allocated buffers.
  private Vector<String> labels = new Vector<String>();
  private int[] intValues;
  // outputLocations: array of shape [Batchsize, NUM_DETECTIONS,4]
  // contains the location of detected boxes
  private float[][][] outputLocations;
  // outputClasses: array of shape [Batchsize, NUM_DETECTIONS]
  // contains the classes of detected boxes
  private float[][] outputClasses;
  // outputScores: array of shape [Batchsize, NUM_DETECTIONS]
  // contains the scores of detected boxes
  private float[][] outputScores;
  // numDetections: array of shape [Batchsize]
  // contains the number of detected boxes
  private float[] numDetections;
  private ByteBuffer imgData;
  private Interpreter tfLite;
  private TFLiteObjectDetectionAPIModel() {}
  /** Memory-map the model file in Assets. */
  private static MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename)
      throws IOException {
    AssetFileDescriptor fileDescriptor = assets.openFd(modelFilename);
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  }
  /**
   * Initializes a native TensorFlow session for classifying images.
   *
   * @param assetManager The asset manager to be used to load assets.
   * @param modelFilename The filepath of the model GraphDef protocol buffer.
   * @param labelFilename The filepath of label file for classes.
   * @param inputSize The size of image input
   * @param isQuantized Boolean representing model is quantized or not
   */
  public static Classifier create(
      final AssetManager assetManager,
      final String modelFilename,
      final String labelFilename,
      final int inputSize,
      final boolean isQuantized)
      throws IOException {
    final TFLiteObjectDetectionAPIModel d = new TFLiteObjectDetectionAPIModel();
    InputStream labelsInput = null;
    String actualFilename = labelFilename.split("file:///android_asset/")[1];
    labelsInput = assetManager.open(actualFilename);
    BufferedReader br = null;
    br = new BufferedReader(new InputStreamReader(labelsInput));
    String line;
    while ((line = br.readLine()) != null) {
      LOGGER.w(line);
      d.labels.add(line);
    }
    br.close();
    d.inputSize = inputSize;
    try {
      d.tfLite = new Interpreter(loadModelFile(assetManager, modelFilename));
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
    d.isModelQuantized = isQuantized;
    // Pre-allocate buffers.
    int numBytesPerChannel;
    if (isQuantized) {
      numBytesPerChannel = 1; // Quantized
    } else {
      numBytesPerChannel = 4; // Floating point
    }
    d.imgData = ByteBuffer.allocateDirect(1 * d.inputSize * d.inputSize * 3 * numBytesPerChannel);
    d.imgData.order(ByteOrder.nativeOrder());
    d.intValues = new int[d.inputSize * d.inputSize];
    d.tfLite.setNumThreads(NUM_THREADS);
    d.outputLocations = new float[1][NUM_DETECTIONS][4];
    d.outputClasses = new float[1][NUM_DETECTIONS];
    d.outputScores = new float[1][NUM_DETECTIONS];
    d.numDetections = new float[1];
    return d;
  }
  @Override
  public List<Recognition> recognizeImage(final Bitmap bitmap) {
    // Log this method so that it can be analyzed with systrace.
    Trace.beginSection("recognizeImage");
    Trace.beginSection("preprocessBitmap");
    // Preprocess the image data from 0-255 int to normalized float based
    // on the provided parameters.
    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
    imgData.rewind();
    for (int i = 0; i < inputSize; ++i) {
      for (int j = 0; j < inputSize; ++j) {
        int pixelValue = intValues[i * inputSize + j];
        if (isModelQuantized) {
          // Quantized model
          imgData.put((byte) ((pixelValue >> 16) & 0xFF));
          imgData.put((byte) ((pixelValue >> 8) & 0xFF));
          imgData.put((byte) (pixelValue & 0xFF));
        } else { // Float model
          imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
          imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
          imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
        }
      }
    }
    Trace.endSection(); // preprocessBitmap
    // Copy the input data into TensorFlow.
    Trace.beginSection("feed");
    outputLocations = new float[1][NUM_DETECTIONS][4];
    outputClasses = new float[1][NUM_DETECTIONS];
    outputScores = new float[1][NUM_DETECTIONS];
    numDetections = new float[1];
    Object[] inputArray = {imgData};
    Map<Integer, Object> outputMap = new HashMap<>();
    outputMap.put(0, outputLocations);
    outputMap.put(1, outputClasses);
    outputMap.put(2, outputScores);
    outputMap.put(3, numDetections);
    Trace.endSection();
    // Run the inference call.
    Trace.beginSection("run");
    tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
    Trace.endSection();
    // Show the best detections.
    // after scaling them back to the input size.
    final ArrayList<Recognition> recognitions = new ArrayList<>(NUM_DETECTIONS);
    for (int i = 0; i < NUM_DETECTIONS; ++i) {
      final RectF detection =
          new RectF(
              outputLocations[0][i][1] * inputSize,
              outputLocations[0][i][0] * inputSize,
              outputLocations[0][i][3] * inputSize,
              outputLocations[0][i][2] * inputSize);
      // SSD Mobilenet V1 Model assumes class 0 is background class
      // in label file and class labels start from 1 to number_of_classes+1,
      // while outputClasses correspond to class index from 0 to number_of_classes
      int labelOffset = 1;
      recognitions.add(
          new Recognition(
              "" + i,
              labels.get((int) outputClasses[0][i] + labelOffset),
              outputScores[0][i],
              detection));
    }
    Trace.endSection(); // "recognizeImage"
    return recognitions;
  }
  @Override
  public void enableStatLogging(final boolean logStats) {}
  @Override
  public String getStatString() {
    return "";
  }
  @Override
  public void close() {}
  public void setNumThreads(int num_threads) {
    if (tfLite != null) tfLite.setNumThreads(num_threads);
  }
  @Override
  public void setUseNNAPI(boolean isChecked) {
    if (tfLite != null) tfLite.setUseNNAPI(isChecked);
  }
}

确保替换 modelPath 参数为你的模型文件在 assets 目录中的路径。

  1. 在你的应用程序中使用 TFLiteObjectDetectionAPIModel 类进行推理。以下是一个简单的示例:


@Override
public void onPreviewSizeChosen(final Size size, final int rotation) {
  final float textSizePx =
      TypedValue.applyDimension(
          TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
  borderedText = new BorderedText(textSizePx);
  borderedText.setTypeface(Typeface.MONOSPACE);
  tracker = new MultiBoxTracker(this);
  int cropSize = TF_OD_API_INPUT_SIZE;
  try {
    detector =
        TFLiteObjectDetectionAPIModel.create(
            getAssets(),
            TF_OD_API_MODEL_FILE,
            TF_OD_API_LABELS_FILE,
            TF_OD_API_INPUT_SIZE,
            TF_OD_API_IS_QUANTIZED);
    cropSize = TF_OD_API_INPUT_SIZE;
  } catch (final IOException e) {
    e.printStackTrace();
    LOGGER.e(e, "Exception initializing classifier!");
    Toast toast =
        Toast.makeText(
            getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT);
    toast.show();
    finish();
  }
  // 解析输出数据
// ...

根据你的模型和任务,你可能需要根据模型的规范和文档来解析输出数据。

输出解析文本数据


需要项目源码私聊

相关文章
|
15小时前
|
存储 监控 安全
打造高效移动办公环境:Android与iOS平台的集成策略
【5月更文挑战第15天】 在数字化时代,移动办公不再是一种奢望,而是日常工作的必需。随着智能手机和平板电脑的性能提升,Android与iOS设备已成为职场人士的重要工具。本文深入探讨了如何通过技术整合,提高两大移动平台在企业环境中的协同工作能力,重点分析了各自平台上的系统集成策略、安全性考虑以及跨平台协作工具的应用。通过对现有技术的剖析与案例研究,旨在为读者提供一套实用的移动办公解决方案。
|
15小时前
|
安全 物联网 Android开发
构建未来:Android与IoT设备的无缝集成
【5月更文挑战第10天】 在数字化时代的浪潮中,智能设备与互联网的结合日益紧密。本文深入探讨了Android系统如何通过其开放性和灵活性成为连接物联网(IoT)设备的关键枢纽。我们将分析Android平台与IoT设备集成的技术途径,探索它们如何共同塑造智能家居、可穿戴技术以及工业自动化等领域的未来。文中不仅阐述了当前的发展状况,还展望了未来的发展趋势,特别是安全性和隐私保护方面的挑战及对策。
10 1
|
15小时前
|
Android开发
Android 高通平台集成无源码apk示例
Android 高通平台集成无源码apk示例
16 0
|
15小时前
|
Android开发
Android 集成vendor下的模块
Android 集成vendor下的模块
12 0
|
15小时前
|
传感器 Java 开发工具
[NDK/JNI系列03] Android Studio集成NDK开发环境
[NDK/JNI系列03] Android Studio集成NDK开发环境
23 0
|
15小时前
|
机器学习/深度学习 PyTorch TensorFlow
NumPy与TensorFlow/PyTorch的集成实践
【4月更文挑战第17天】本文探讨了NumPy与主流深度学习框架TensorFlow和PyTorch的集成实践,阐述了它们如何通过便捷的数据转换提升开发效率和模型性能。在TensorFlow中,NumPy数组可轻松转为Tensor,反之亦然,便于原型设计和大规模训练。PyTorch的张量与NumPy数组在内存中共享,实现无缝转换。尽管集成带来了性能和内存管理的考量,但这种结合为机器学习流程提供了强大支持,促进了AI技术的发展。
|
15小时前
|
移动开发 监控 安全
mPaaS常见问题之Android集成dexPatch热修复运行时候无法正常进行热更新如何解决
mPaaS(移动平台即服务,Mobile Platform as a Service)是阿里巴巴集团提供的一套移动开发解决方案,它包含了一系列移动开发、测试、监控和运营的工具和服务。以下是mPaaS常见问题的汇总,旨在帮助开发者和企业用户解决在使用mPaaS产品过程中遇到的各种挑战
39 0
|
15小时前
|
前端开发 Java Maven
java集成opencv(不踩坑),实现人脸检测小demo(含上传人像图片识别接口),windows,IDEA,Springboot
java集成opencv(不踩坑),实现人脸检测小demo(含上传人像图片识别接口),windows,IDEA,Springboot
209 0
|
15小时前
|
机器学习/深度学习 Dart TensorFlow
TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11(5)
TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11(5)
75 0
|
15小时前
|
机器学习/深度学习 存储 编解码
TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11(4)
TensorFlow Lite,ML Kit 和 Flutter 移动深度学习:6~11(4)
125 0