To recognize the specific name of an object captured by the phone camera, as shown in the figure above, a model trained with TensorFlow has to be referenced from the project. The detailed steps for integrating TensorFlow follow; please work through them in order:
- Add the TensorFlow Maven repository to the build.gradle file in the project's root directory. In the repositories section, add the following lines:
```groovy
allprojects {
    repositories {
        // other repositories...
        maven { url 'https://google.bintray.com/tensorflow' }
    }
}
```
- Add the TensorFlow Lite dependency to the app module's build.gradle file. In the dependencies section, add the following line:
```groovy
implementation 'org.tensorflow:tensorflow-lite:2.5.0'
```
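For orientation, this line goes inside the dependencies block of the module-level build.gradle. The GPU delegate shown commented out below is an optional extra (an assumption, not required by the steps in this article):

```groovy
dependencies {
    // TensorFlow Lite runtime used by the code in this article
    implementation 'org.tensorflow:tensorflow-lite:2.5.0'

    // Optional (assumption, not needed for the steps here): GPU delegate for hardware acceleration
    // implementation 'org.tensorflow:tensorflow-lite-gpu:2.5.0'
}
```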
- Add the TensorFlow Lite model file to your Android project: copy the model file (.tflite) into the app/src/main/assets directory. If the assets directory does not exist, create it manually (a related build setting is sketched below).
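One build setting worth noting (an assumption based on how the loadModelFile method below works, not an explicit step in this article): because the model is memory-mapped from assets, the .tflite file should be stored uncompressed. A snippet like the following in the module-level build.gradle keeps AAPT from compressing it:

```groovy
android {
    // ...existing android configuration...

    aaptOptions {
        // Keep .tflite files uncompressed so they can be memory-mapped from assets
        noCompress "tflite"
    }
}
```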
- Create a TFLiteObjectDetectionAPIModel class that loads and runs the TensorFlow Lite model. Here is an example:
```java
public class TFLiteObjectDetectionAPIModel implements Classifier {
  private static final Logger LOGGER = new Logger();

  // Only return this many results.
  private static final int NUM_DETECTIONS = 10;
  // Float model
  private static final float IMAGE_MEAN = 128.0f;
  private static final float IMAGE_STD = 128.0f;
  // Number of threads in the java app
  private static final int NUM_THREADS = 4;

  private boolean isModelQuantized;
  // Config values.
  private int inputSize;
  // Pre-allocated buffers.
  private Vector<String> labels = new Vector<String>();
  private int[] intValues;
  // outputLocations: array of shape [Batchsize, NUM_DETECTIONS, 4]
  // contains the location of detected boxes
  private float[][][] outputLocations;
  // outputClasses: array of shape [Batchsize, NUM_DETECTIONS]
  // contains the classes of detected boxes
  private float[][] outputClasses;
  // outputScores: array of shape [Batchsize, NUM_DETECTIONS]
  // contains the scores of detected boxes
  private float[][] outputScores;
  // numDetections: array of shape [Batchsize]
  // contains the number of detected boxes
  private float[] numDetections;

  private ByteBuffer imgData;

  private Interpreter tfLite;

  private TFLiteObjectDetectionAPIModel() {}

  /** Memory-map the model file in Assets. */
  private static MappedByteBuffer loadModelFile(AssetManager assets, String modelFilename)
      throws IOException {
    AssetFileDescriptor fileDescriptor = assets.openFd(modelFilename);
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  }

  /**
   * Initializes a native TensorFlow session for classifying images.
   *
   * @param assetManager The asset manager to be used to load assets.
   * @param modelFilename The filepath of the model GraphDef protocol buffer.
   * @param labelFilename The filepath of label file for classes.
   * @param inputSize The size of image input
   * @param isQuantized Boolean representing model is quantized or not
   */
  public static Classifier create(
      final AssetManager assetManager,
      final String modelFilename,
      final String labelFilename,
      final int inputSize,
      final boolean isQuantized)
      throws IOException {
    final TFLiteObjectDetectionAPIModel d = new TFLiteObjectDetectionAPIModel();

    InputStream labelsInput = null;
    String actualFilename = labelFilename.split("file:///android_asset/")[1];
    labelsInput = assetManager.open(actualFilename);
    BufferedReader br = null;
    br = new BufferedReader(new InputStreamReader(labelsInput));
    String line;
    while ((line = br.readLine()) != null) {
      LOGGER.w(line);
      d.labels.add(line);
    }
    br.close();

    d.inputSize = inputSize;

    try {
      d.tfLite = new Interpreter(loadModelFile(assetManager, modelFilename));
    } catch (Exception e) {
      throw new RuntimeException(e);
    }

    d.isModelQuantized = isQuantized;
    // Pre-allocate buffers.
    int numBytesPerChannel;
    if (isQuantized) {
      numBytesPerChannel = 1; // Quantized
    } else {
      numBytesPerChannel = 4; // Floating point
    }
    d.imgData = ByteBuffer.allocateDirect(1 * d.inputSize * d.inputSize * 3 * numBytesPerChannel);
    d.imgData.order(ByteOrder.nativeOrder());
    d.intValues = new int[d.inputSize * d.inputSize];

    d.tfLite.setNumThreads(NUM_THREADS);
    d.outputLocations = new float[1][NUM_DETECTIONS][4];
    d.outputClasses = new float[1][NUM_DETECTIONS];
    d.outputScores = new float[1][NUM_DETECTIONS];
    d.numDetections = new float[1];
    return d;
  }

  @Override
  public List<Recognition> recognizeImage(final Bitmap bitmap) {
    // Log this method so that it can be analyzed with systrace.
    Trace.beginSection("recognizeImage");

    Trace.beginSection("preprocessBitmap");
    // Preprocess the image data from 0-255 int to normalized float based
    // on the provided parameters.
    bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());

    imgData.rewind();
    for (int i = 0; i < inputSize; ++i) {
      for (int j = 0; j < inputSize; ++j) {
        int pixelValue = intValues[i * inputSize + j];
        if (isModelQuantized) {
          // Quantized model
          imgData.put((byte) ((pixelValue >> 16) & 0xFF));
          imgData.put((byte) ((pixelValue >> 8) & 0xFF));
          imgData.put((byte) (pixelValue & 0xFF));
        } else {
          // Float model
          imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
          imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
          imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
        }
      }
    }
    Trace.endSection(); // preprocessBitmap

    // Copy the input data into TensorFlow.
    Trace.beginSection("feed");
    outputLocations = new float[1][NUM_DETECTIONS][4];
    outputClasses = new float[1][NUM_DETECTIONS];
    outputScores = new float[1][NUM_DETECTIONS];
    numDetections = new float[1];

    Object[] inputArray = {imgData};
    Map<Integer, Object> outputMap = new HashMap<>();
    outputMap.put(0, outputLocations);
    outputMap.put(1, outputClasses);
    outputMap.put(2, outputScores);
    outputMap.put(3, numDetections);
    Trace.endSection();

    // Run the inference call.
    Trace.beginSection("run");
    tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
    Trace.endSection();

    // Show the best detections,
    // after scaling them back to the input size.
    final ArrayList<Recognition> recognitions = new ArrayList<>(NUM_DETECTIONS);
    for (int i = 0; i < NUM_DETECTIONS; ++i) {
      final RectF detection =
          new RectF(
              outputLocations[0][i][1] * inputSize,
              outputLocations[0][i][0] * inputSize,
              outputLocations[0][i][3] * inputSize,
              outputLocations[0][i][2] * inputSize);
      // SSD Mobilenet V1 Model assumes class 0 is background class
      // in label file and class labels start from 1 to number_of_classes+1,
      // while outputClasses correspond to class index from 0 to number_of_classes
      int labelOffset = 1;
      recognitions.add(
          new Recognition(
              "" + i,
              labels.get((int) outputClasses[0][i] + labelOffset),
              outputScores[0][i],
              detection));
    }
    Trace.endSection(); // "recognizeImage"
    return recognitions;
  }

  @Override
  public void enableStatLogging(final boolean logStats) {}

  @Override
  public String getStatString() {
    return "";
  }

  @Override
  public void close() {}

  public void setNumThreads(int num_threads) {
    if (tfLite != null) tfLite.setNumThreads(num_threads);
  }

  @Override
  public void setUseNNAPI(boolean isChecked) {
    if (tfLite != null) tfLite.setUseNNAPI(isChecked);
  }
}
```
Make sure the modelFilename argument passed to create() points to your model file's path inside the assets directory, and that labelFilename uses the file:///android_asset/ prefix, since the code strips that prefix before opening the label file.
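The class above also depends on a Classifier interface (with a Recognition result type) and a Logger helper from the TensorFlow Lite example app, which are not shown in this article. As a minimal sketch, assuming the names used in that example app, the interface might look roughly like this; adapt it to your own project:

```java
import android.graphics.Bitmap;
import android.graphics.RectF;
import java.util.List;

/** Minimal sketch of the Classifier interface assumed by TFLiteObjectDetectionAPIModel. */
public interface Classifier {
  List<Recognition> recognizeImage(Bitmap bitmap);

  void enableStatLogging(boolean debug);

  String getStatString();

  void close();

  void setNumThreads(int numThreads);

  void setUseNNAPI(boolean isChecked);

  /** An immutable result describing a single detected object. */
  class Recognition {
    private final String id;        // Unique identifier within the result list
    private final String title;     // Display name of the recognized class
    private final Float confidence; // Score for the recognition (higher is better)
    private RectF location;         // Bounding box within the source image

    public Recognition(String id, String title, Float confidence, RectF location) {
      this.id = id;
      this.title = title;
      this.confidence = confidence;
      this.location = location;
    }

    public String getId() { return id; }

    public String getTitle() { return title; }

    public Float getConfidence() { return confidence; }

    public RectF getLocation() { return new RectF(location); }

    public void setLocation(RectF location) { this.location = location; }
  }
}
```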
- Use the TFLiteObjectDetectionAPIModel class in your application to run inference. Here is a simple example:
```java
@Override
public void onPreviewSizeChosen(final Size size, final int rotation) {
  final float textSizePx =
      TypedValue.applyDimension(
          TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
  borderedText = new BorderedText(textSizePx);
  borderedText.setTypeface(Typeface.MONOSPACE);

  tracker = new MultiBoxTracker(this);

  int cropSize = TF_OD_API_INPUT_SIZE;

  try {
    detector =
        TFLiteObjectDetectionAPIModel.create(
            getAssets(),
            TF_OD_API_MODEL_FILE,
            TF_OD_API_LABELS_FILE,
            TF_OD_API_INPUT_SIZE,
            TF_OD_API_IS_QUANTIZED);
    cropSize = TF_OD_API_INPUT_SIZE;
  } catch (final IOException e) {
    e.printStackTrace();
    LOGGER.e(e, "Exception initializing classifier!");
    Toast toast =
        Toast.makeText(
            getApplicationContext(), "Classifier could not be initialized", Toast.LENGTH_SHORT);
    toast.show();
    finish();
  }

  // Parse the output data
  // ...
}
```
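The snippet above only initializes the detector. As a rough sketch of how it might be invoked on a camera frame (the processImage hook, the runInBackground helper, the croppedBitmap preparation, and the 0.5f threshold are all assumptions borrowed from the example app's structure, not part of this article):

```java
// Hypothetical sketch: run the detector on a frame that has already been
// cropped/scaled to TF_OD_API_INPUT_SIZE x TF_OD_API_INPUT_SIZE.
protected void processImage(final Bitmap croppedBitmap) {
  runInBackground(
      new Runnable() {
        @Override
        public void run() {
          final List<Classifier.Recognition> results = detector.recognizeImage(croppedBitmap);
          for (final Classifier.Recognition result : results) {
            // Keep only detections above a minimum confidence (0.5f is an example value)
            if (result.getConfidence() >= 0.5f) {
              LOGGER.i("Detected " + result.getTitle() + " (" + result.getConfidence() + ")");
            }
          }
        }
      });
}
```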
Depending on your model and task, you may need to parse the output data according to the model's specification and documentation.
Parsing the output into text data
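As an illustrative sketch (the confidence threshold and the output format are assumptions, not taken from any model's documentation), the recognitions returned by recognizeImage could be flattened into display text like this:

```java
// Illustrative helper: format the recognitions as a human-readable string.
private String describeResults(List<Classifier.Recognition> results) {
  StringBuilder sb = new StringBuilder();
  for (Classifier.Recognition r : results) {
    // Skip low-confidence detections; 0.5f is an arbitrary example threshold
    if (r.getConfidence() < 0.5f) {
      continue;
    }
    sb.append(r.getTitle())
        .append(": ")
        .append(String.format("%.1f%%", r.getConfidence() * 100f))
        .append('\n');
  }
  return sb.toString();
}
```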
If you need the project source code, send me a private message.