TensorFlow Lite
TensorFlow Lite 是一种用于设备端推断的开源深度学习框架，可帮助开发者在移动设备和 IoT 设备上部署机器学习模型。
环境
Android Studio 4.0 + Java
数字分类器
通过 TensorFlow Lite 模型对手写数字进行分类。
关于DEMO
GitHub Demo 源码（Kotlin）
上图所示 Demo 的 GitHub 源码（Java）
历程
起初没有找到上图中 Demo 的 Java 源码，于是参照 Kotlin 版 Demo 自行移植了一份。以下是移植过程，已熟悉 TensorFlow Lite 的读者可直接跳过。
在 Android Studio 中新建 Module：DigitClassifierByTFL。
编译环境
build.gradle 中与 SDK 相关的配置：
// Module-level Android SDK configuration for the DigitClassifierByTFL demo.
android {
    compileSdkVersion 30
    buildToolsVersion "30.0.2"
    defaultConfig {
        applicationId "com.ansondroider.digitclassifierbytfl"
        minSdkVersion 16
        // NOTE(review): targetSdkVersion is far below compileSdkVersion (16 vs 30); the
        // Activity calls requestPermissions() on API >= 23, which only matters for apps
        // targeting 23+. Confirm the low target is intended.
        targetSdkVersion 16
        versionCode 1
        versionName "1.0"
    }
}
目录结构
等待构建完成后，还需修改以下配置：
build.gradle：设置不压缩 .tflite 文件。模型是通过 AssetManager.openFd() 以内存映射方式加载的，而被压缩的 asset 无法这样打开，缺少此配置会导致运行时加载模型出错。
// Keep .tflite assets uncompressed in the APK: the model is memory-mapped through
// AssetManager.openFd(), which cannot open compressed assets.
aaptOptions {
    noCompress "tflite"
}
build.gradle：添加 TensorFlow Lite 依赖
dependencies {
    implementation fileTree(dir: "libs", include: ["*.jar"])
    // Pin a released TensorFlow Lite version so builds are reproducible. A nightly
    // coordinate with {changing = true} re-resolves an unreleased snapshot on every build
    // (and 'org.tensorflow:tensorflow-lite:0.0.0-nightly' is no longer published).
    // Interpreter.Options.setUseNNAPI, getInputTensor/getOutputTensor and Tensor.shape()
    // used by this demo are all available in this release.
    implementation 'org.tensorflow:tensorflow-lite:2.3.0'
}
源码及说明
layout
<?xml version="1.0" encoding="utf-8"?>
<!-- Main screen: a 400dp x 400dp drawing area centred at the top and a result label
     pinned to the bottom. -->
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent">

    <!-- Finger-drawing surface; the Activity classifies the bitmap it exposes. -->
    <com.ansondroider.digitclassifierbytfl.PaintView
        android:layout_width="400dp"
        android:layout_height="400dp"
        android:layout_centerHorizontal="true"
        android:id="@+id/paintView"/>

    <!-- Classification output; shows "Result: ?" until the first digit is classified. -->
    <TextView
        android:id="@+id/tvRes"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_alignParentBottom="true"
        android:textColor="#FF00FF00"
        android:text="Result: ?"
        android:textSize="30sp"/>
</RelativeLayout>
PaintView: 手指绘画.
TextView: 显示结果.
Activity
package com.ansondroider.digitclassifierbytfl;

import android.Manifest;
import android.app.Activity;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.os.Build;
import android.os.Bundle;
import android.widget.TextView;

import org.tensorflow.lite.Interpreter;

import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;

/**
 * Demo screen: the user draws a digit on a {@link PaintView}; once the stroke is finished the
 * drawing is classified by the nested {@code Classifier} (a TensorFlow Lite wrapper) and the
 * result is shown in a TextView.
 */
public class DigitClassifierByTFL extends Activity {
    /** Drawing surface; its bitmap is the classifier input. */
    PaintView paintView;
    /** Displays the formatted classification result. */
    TextView tvRes;
    /** TFLite classifier; created in onCreate() and closed in onDestroy(). */
    Classifier fier;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);

        // Set up the UI.
        setContentView(R.layout.activity_digit_classifier_by_tfl);
        tvRes = (TextView) findViewById(R.id.tvRes);
        paintView = (PaintView) findViewById(R.id.paintView);

        // Create the classifier before registering the callback that uses it.
        fier = new Classifier(getAssets());

        // Classify as soon as PaintView reports a finished drawing
        // (PaintView fires the callback 500 ms after the finger is lifted).
        paintView.setCallback(new PaintView.Callback() {
            @Override
            public void onWriteDone() {
                final String res = fier.classifier(paintView.getBitmap());
                // Post the UI update so it is applied on the UI thread regardless of which
                // thread delivered the callback.
                tvRes.post(new Runnable() {
                    @Override
                    public void run() {
                        tvRes.setText(res);
                    }
                });
            }
        });

        if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
            requestPermissions(new String[]{
                    Manifest.permission.WRITE_EXTERNAL_STORAGE,
                    Manifest.permission.READ_EXTERNAL_STORAGE}, 0x77);
        }
    }

    @Override
    protected void onDestroy() {
        super.onDestroy();
        // PaintView may still hold a delayed "write done" runnable. Detach the callback first
        // so it can never run inference on the interpreter that is closed below.
        paintView.setCallback(null);
        fier.close();
    }
PaintView
package com.ansondroider.digitclassifierbytfl; import android.content.Context; import android.graphics.Bitmap; import android.graphics.Canvas; import android.graphics.Color; import android.graphics.Paint; import android.graphics.Path; import android.util.AttributeSet; import android.view.MotionEvent; import android.view.View; public class PaintView extends View { public PaintView(Context context) { super(context); } public PaintView(Context context, AttributeSet attrs) { super(context, attrs); } public PaintView(Context context, AttributeSet attrs, int defStyleAttr) { super(context, attrs, defStyleAttr); } Bitmap bm; public Bitmap getBitmap(){ return bm; } @Override protected void onSizeChanged(int w, int h, int oldw, int oldh) { super.onSizeChanged(w, h, oldw, oldh); if(bm != null){ bm.recycle(); } bm = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888); cBm = new Canvas(bm); cBm.drawColor(Color.BLACK); } Canvas cBm; Paint p = new Paint(Paint.ANTI_ALIAS_FLAG); float dx, dy, cx, cy; @Override public boolean onTouchEvent(MotionEvent event) { cx = event.getX(); cy = event.getY(); switch(event.getAction()){ case MotionEvent.ACTION_DOWN: dx = cx; dy = cy; startWrite(); break; case MotionEvent.ACTION_MOVE: onMove(); break; case MotionEvent.ACTION_CANCEL: case MotionEvent.ACTION_UP: endWrite(); break; } postInvalidate(); return true; } Path path = new Path(); void startWrite(){ removeCallbacks(writeDone); path.moveTo(cx, cy); //cBm.drawColor(Color.WHITE); } void onMove(){ path.lineTo(cx, cy); cBm.drawColor(Color.BLACK); p.setStyle(Paint.Style.STROKE); p.setColor(Color.WHITE); p.setStrokeWidth(50); cBm.drawPath(path, p); } Runnable writeDone = new Runnable() { @Override public void run() { if(cb != null)cb.onWriteDone(); path.reset(); } }; void endWrite(){ removeCallbacks(writeDone); postDelayed(writeDone, 500); } @Override protected void onDraw(Canvas canvas) { if(bm != null && !bm.isRecycled())canvas.drawBitmap(bm, 0, 0, p); } Callback cb; public void 
setCallback(Callback c){ cb = c; } public interface Callback{ void onWriteDone(); } }
分类器
class Classifier{ //mnist.tflite: 来自kotlin DEMO, 识别率很低, 文件小. //mnist_big.tflite: 来自JAVA DEMO, 识别率高,文件大 final String MODEL = "mnist_big.tflite"; //插入器 Interpreter interpreter; //输入识别的图像尺寸 int bmWidth, bmHeight; //用于读取tflite文件 AssetManager asset; //用于创建ByteBuffer int modelInputSize; Classifier(AssetManager asset){ this.asset = asset; //创建插入器. Interpreter.Options op = new Interpreter.Options(); op.setUseNNAPI(true); interpreter = new Interpreter(loadModel(), op); //获取输入信息 int[] shape = interpreter.getInputTensor(0).shape(); bmWidth = shape[1]; bmHeight = shape[2]; //计算ByteBuffer大小. int FLOAT_TYPE_SIZE = 4; int PIXEL_SIZE = 1; modelInputSize = FLOAT_TYPE_SIZE * bmWidth * bmHeight * PIXEL_SIZE; } //加载模型 ByteBuffer loadModel(){ try { AssetFileDescriptor fd = asset.openFd(MODEL); FileInputStream is = new FileInputStream(fd.getFileDescriptor()); FileChannel channel = is.getChannel(); long startOffset = fd.getStartOffset(); long declareLength = fd.getDeclaredLength(); return channel.map(FileChannel.MapMode.READ_ONLY, startOffset, declareLength); } catch (IOException e) { e.printStackTrace(); } return null; } //执行分类 String classifier(Bitmap bm){ //缩放图片到指定尺寸.(28*28) Bitmap nbm = Bitmap.createScaledBitmap(bm, bmWidth, bmHeight, true); ByteBuffer byteBuffer = convertBitmapToByteBuffer(nbm); //Kotlin中的代码: // val result = Array(1) { FloatArray(OUTPUT_CLASSES_COUNT) } // 平时不用它, 看这代码头痛了好久. //若创建的数组不对, 如用float[10], 或float[2][10] //则会导致异常(Google 百度都不知道): /** 2020-09-10 10:55:18.213 14429-14429/com.ansondroider.digitclassifierbytfl E/AndroidRuntime: FATAL EXCEPTION: main Process: com.ansondroider.digitclassifierbytfl, PID: 14429 java.lang.IllegalArgumentException: Cannot copy from a TensorFlowLite tensor (softmax_tensor) with shape [1, 10] to a Java object with shape [2, 10]. 
at org.tensorflow.lite.Tensor.throwIfDstShapeIsIncompatible(Tensor.java:482) at org.tensorflow.lite.Tensor.copyTo(Tensor.java:252) at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:170) at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:347) at org.tensorflow.lite.Interpreter.run(Interpreter.java:306) at com.ansondroider.digitclassifierbytfl.DigitClassifierByTFL$Classifier.classifier(DigitClassifierByTFL.java:98) at com.ansondroider.digitclassifierbytfl.DigitClassifierByTFL$1.onWriteDone(DigitClassifierByTFL.java:33) at com.ansondroider.digitclassifierbytfl.PaintView$1.run(PaintView.java:86) at android.os.Handler.handleCallback(Handler.java:883) at android.os.Handler.dispatchMessage(Handler.java:100) at android.os.Looper.loop(Looper.java:214) at android.app.ActivityThread.main(ActivityThread.java:7356) at java.lang.reflect.Method.invoke(Native Method) at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:492) at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:930)**/ float[][] result = new float[1][10]; //执行. interpreter.run(byteBuffer, result); //格式化输出 return getOutputSting(result); } String getOutputSting(float[][] floats){ //Kotlin 代码: // val maxIndex = output.indices.maxBy { output[it] } ?: -1 // return "Prediction Result: %d\nConfidence: %2f".format(maxIndex, output[maxIndex]) //在float[10]数组中, 存放了推算的结果, 下标分别对应的是[0,9]的数字. //只需要遍历10个数中, 找出最大的值即可. StringBuilder result = new StringBuilder("Result:\n"); float[] res = floats[0]; float max = -1; int v = -1; for(int i = 0; i < res.length; i ++){ result.append("[" + i + "]=" + res[i]).append("\n"); if(max < res[i]){ max = res[i]; v = i; } } result.append("BEST: " + v); return result.toString(); } ByteBuffer convertBitmapToByteBuffer(Bitmap bm){ //刚开始, 用错了函数接口: ByteBuffer.allocate //这样会导致推算的结果不管输入如何变化, 都输出固定的float[10] //在打开后一直不变, 而在调试过程中, 也出现过多次生新运行都显示同样的结果. 
ByteBuffer bf = ByteBuffer.allocateDirect(modelInputSize); bf.order(ByteOrder.nativeOrder()); int[] pixels = new int[bmWidth * bmHeight]; bm.getPixels(pixels, 0, bm.getWidth(), 0, 0, bm.getWidth(), bm.getHeight()); for(int i = 0; i < pixels.length; i ++){ int r = (pixels[i] >> 16) & 0xFF; int g = (pixels[i] >> 8) & 0xFF; int b = pixels[i] & 0xFF; float normalizePixelValue = (r + g + b) / 3f / 255f; bf.putFloat(normalizePixelValue); } return bf; } void close(){ interpreter.close(); } } }
DEMO
Demo源码下载
相关
TensorFlow Lite(Jinpeng)
TensorFlow Lite 指南
TensorFlow Lite 示例应用
Kotlin入门(4)声明与操作数组
Kotlin Array
google / tensorflow / tensorflow-lite