Android TensorFlow Lite 初探 数字分类器(JAVA DEMO)

简介: Android TensorFlow Lite 初探 数字分类器(JAVA DEMO)

TensorFlow Lite


TensorFlow Lite 是一种用于设备端推断的开源深度学习框架,

在移动设备和 IoT 设备上部署机器学习模型


环境


AndroidStudio 4.0 + JAVA


数字分类器


通过 TensorFlow Lite 模型对手写数字进行分类。

image.png


关于DEMO


Github demo 源码(kotlin)

上面图片的DEMO Github demo 源码(Java)


历程


刚开始并没有找到图片中的DEMO源码, 于是自己根据kotlin的DEMO移植了一下, 以下是移植的过程, 若对TensorFlow Lite已有所了解, 请自行跳过.


在AS中新建Module DigitClassifierByTFL.

编译环境

image.png

build.gradle中SDK相关配置:

android {
    compileSdkVersion 30
    buildToolsVersion "30.0.2"
    defaultConfig {
        applicationId "com.ansondroider.digitclassifierbytfl"
        minSdkVersion 16
        targetSdkVersion 16
        versionCode 1
        versionName "1.0"
    }
}


目录结构

image.png

等待构建完成, 需修改一些配置:

build.gradle: 不压缩 .tflite 文件, 若不加会因为导入模型有问题导致运行出错

aaptOptions {
        noCompress "tflite"
    }


build.gradle: 增加 TensorFlow Lite依赖

dependencies {
    implementation fileTree(dir: "libs", include: ["*.jar"])
    implementation ('org.tensorflow:tensorflow-lite:0.0.0-nightly'){changing = true}
}


源码及说明

layout

<?xml version="1.0" encoding="utf-8"?>
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent">
    <com.ansondroider.digitclassifierbytfl.PaintView
        android:layout_width="400dp"
        android:layout_height="400dp"
        android:layout_centerHorizontal="true"
        android:id="@+id/paintView"/>
    <TextView
        android:id="@+id/tvRes"
        android:layout_width="match_parent"
        android:layout_height="wrap_content"
        android:layout_alignParentBottom="true"
        android:textColor="#FF00FF00"
        android:text="Result: ?"
        android:textSize="30sp"/>
</RelativeLayout>


PaintView: 手指绘画.

TextView: 显示结果.

image.pngimage.png


Activity

package com.ansondroider.digitclassifierbytfl;
import android.Manifest;
import android.app.Activity;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.graphics.Bitmap;
import android.os.Build;
import android.os.Bundle;
import android.widget.TextView;
import org.tensorflow.lite.Interpreter;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
public class DigitClassifierByTFL extends Activity {
  PaintView paintView;
  TextView tvRes;
  Classifier fier;
  @Override
  protected void onCreate(Bundle savedInstanceState) {
      super.onCreate(savedInstanceState);
      //初始化UI
      setContentView(R.layout.activity_digit_classifier_by_tfl);
      tvRes = (TextView)findViewById(R.id.tvRes);
      paintView = (PaintView)findViewById(R.id.paintView);
      //添加PaintView回调, 当手指绘画完成后, 立即调用分类器进行分类.
      //绘画完成: 当手指抬起后 500 毫秒.
      paintView.setCallback(new PaintView.Callback() {
          @Override
          public void onWriteDone() {
              final String res = fier.classifier(paintView.getBitmap());
              tvRes.post(new Runnable() {
                  @Override
                  public void run() {
                      tvRes.setText(res);
                  }
              });
          }
      });
      //创建分类器
      fier = new Classifier(getAssets());
      if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
          requestPermissions(new String[]{Manifest.permission.WRITE_EXTERNAL_STORAGE, Manifest.permission.READ_EXTERNAL_STORAGE}, 0x77);
      }
  }
  @Override
  protected void onDestroy() {
      super.onDestroy();
      fier.close();
  }


PaintView

package com.ansondroider.digitclassifierbytfl;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Paint;
import android.graphics.Path;
import android.util.AttributeSet;
import android.view.MotionEvent;
import android.view.View;
public class PaintView extends View {
  public PaintView(Context context) {
      super(context);
  }
  public PaintView(Context context, AttributeSet attrs) {
      super(context, attrs);
  }
  public PaintView(Context context, AttributeSet attrs, int defStyleAttr) {
      super(context, attrs, defStyleAttr);
  }
  Bitmap bm;
  public Bitmap getBitmap(){
      return bm;
  }
  @Override
  protected void onSizeChanged(int w, int h, int oldw, int oldh) {
      super.onSizeChanged(w, h, oldw, oldh);
      if(bm != null){
          bm.recycle();
      }
      bm = Bitmap.createBitmap(w, h, Bitmap.Config.ARGB_8888);
      cBm = new Canvas(bm);
      cBm.drawColor(Color.BLACK);
  }
  Canvas cBm;
  Paint p = new Paint(Paint.ANTI_ALIAS_FLAG);
  float dx, dy, cx, cy;
  @Override
  public boolean onTouchEvent(MotionEvent event) {
      cx = event.getX();
      cy = event.getY();
      switch(event.getAction()){
          case MotionEvent.ACTION_DOWN:
              dx = cx;
              dy = cy;
              startWrite();
              break;
          case MotionEvent.ACTION_MOVE:
              onMove();
              break;
          case MotionEvent.ACTION_CANCEL:
          case MotionEvent.ACTION_UP:
              endWrite();
              break;
      }
      postInvalidate();
      return true;
  }
  Path path = new Path();
  void startWrite(){
      removeCallbacks(writeDone);
      path.moveTo(cx, cy);
      //cBm.drawColor(Color.WHITE);
  }
  void onMove(){
      path.lineTo(cx, cy);
      cBm.drawColor(Color.BLACK);
      p.setStyle(Paint.Style.STROKE);
      p.setColor(Color.WHITE);
      p.setStrokeWidth(50);
      cBm.drawPath(path, p);
  }
  Runnable writeDone = new Runnable() {
      @Override
      public void run() {
          if(cb != null)cb.onWriteDone();
          path.reset();
      }
  };
  void endWrite(){
      removeCallbacks(writeDone);
      postDelayed(writeDone, 500);
  }
  @Override
  protected void onDraw(Canvas canvas) {
      if(bm != null && !bm.isRecycled())canvas.drawBitmap(bm, 0, 0, p);
  }
  Callback cb;
  public void setCallback(Callback c){
      cb = c;
  }
  public interface Callback{
      void onWriteDone();
  }
}


分类器

class Classifier{
     //mnist.tflite: 来自kotlin DEMO, 识别率很低, 文件小.
     //mnist_big.tflite: 来自JAVA DEMO, 识别率高,文件大
      final String MODEL = "mnist_big.tflite";
      //插入器
      Interpreter interpreter;
      //输入识别的图像尺寸
      int bmWidth, bmHeight;
      //用于读取tflite文件
      AssetManager asset;
      //用于创建ByteBuffer
      int modelInputSize;
      Classifier(AssetManager asset){
          this.asset = asset;
      //创建插入器.
          Interpreter.Options op = new Interpreter.Options();
          op.setUseNNAPI(true);
          interpreter = new Interpreter(loadModel(), op);
    //获取输入信息
          int[] shape = interpreter.getInputTensor(0).shape();
          bmWidth = shape[1];
          bmHeight = shape[2];
    //计算ByteBuffer大小.
          int FLOAT_TYPE_SIZE = 4;
          int PIXEL_SIZE = 1;
          modelInputSize = FLOAT_TYPE_SIZE * bmWidth * bmHeight * PIXEL_SIZE;
      }
     //加载模型
      ByteBuffer loadModel(){
          try {
              AssetFileDescriptor fd = asset.openFd(MODEL);
              FileInputStream is = new FileInputStream(fd.getFileDescriptor());
              FileChannel channel = is.getChannel();
              long startOffset = fd.getStartOffset();
              long declareLength = fd.getDeclaredLength();
              return channel.map(FileChannel.MapMode.READ_ONLY, startOffset, declareLength);
          } catch (IOException e) {
              e.printStackTrace();
          }
          return null;
      }
     //执行分类
      String classifier(Bitmap bm){
          //缩放图片到指定尺寸.(28*28)
          Bitmap nbm = Bitmap.createScaledBitmap(bm, bmWidth, bmHeight, true);
          ByteBuffer byteBuffer = convertBitmapToByteBuffer(nbm);
      //Kotlin中的代码:
      //  val result = Array(1) { FloatArray(OUTPUT_CLASSES_COUNT) }
      //  平时不用它, 看这代码头痛了好久.
      //若创建的数组不对, 如用float[10], 或float[2][10]
      //则会导致异常(Google 百度都不知道):
      /**  2020-09-10 10:55:18.213 14429-14429/com.ansondroider.digitclassifierbytfl E/AndroidRuntime: FATAL EXCEPTION: main
  Process: com.ansondroider.digitclassifierbytfl, PID: 14429
  java.lang.IllegalArgumentException: Cannot copy from a TensorFlowLite tensor (softmax_tensor) with shape [1, 10] to a Java object with shape [2, 10].
      at org.tensorflow.lite.Tensor.throwIfDstShapeIsIncompatible(Tensor.java:482)
      at org.tensorflow.lite.Tensor.copyTo(Tensor.java:252)
      at org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:170)
      at org.tensorflow.lite.Interpreter.runForMultipleInputsOutputs(Interpreter.java:347)
      at org.tensorflow.lite.Interpreter.run(Interpreter.java:306)
      at com.ansondroider.digitclassifierbytfl.DigitClassifierByTFL$Classifier.classifier(DigitClassifierByTFL.java:98)
      at com.ansondroider.digitclassifierbytfl.DigitClassifierByTFL$1.onWriteDone(DigitClassifierByTFL.java:33)
      at com.ansondroider.digitclassifierbytfl.PaintView$1.run(PaintView.java:86)
      at android.os.Handler.handleCallback(Handler.java:883)
      at android.os.Handler.dispatchMessage(Handler.java:100)
      at android.os.Looper.loop(Looper.java:214)
      at android.app.ActivityThread.main(ActivityThread.java:7356)
      at java.lang.reflect.Method.invoke(Native Method)
      at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:492)
      at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:930)**/
          float[][] result = new float[1][10];
          //执行.
          interpreter.run(byteBuffer, result);
          //格式化输出
          return getOutputSting(result);
      }
      String getOutputSting(float[][] floats){
        //Kotlin 代码:
        //  val maxIndex = output.indices.maxBy { output[it] } ?: -1
      //  return "Prediction Result: %d\nConfidence: %2f".format(maxIndex, output[maxIndex])
    //在float[10]数组中, 存放了推算的结果, 下标分别对应的是[0,9]的数字.
    //只需要遍历10个数中, 找出最大的值即可.
          StringBuilder result = new StringBuilder("Result:\n");
          float[] res = floats[0];
          float max = -1;
          int v = -1;
          for(int i = 0; i < res.length; i ++){
              result.append("[" + i + "]=" + res[i]).append("\n");
              if(max < res[i]){
                  max = res[i];
                  v = i;
              }
          }
          result.append("BEST: " + v);
          return result.toString();
      }
      ByteBuffer convertBitmapToByteBuffer(Bitmap bm){
          //刚开始, 用错了函数接口: ByteBuffer.allocate
          //这样会导致推算的结果不管输入如何变化, 都输出固定的float[10]
          //在打开后一直不变, 而在调试过程中, 也出现过多次生新运行都显示同样的结果.
          ByteBuffer bf = ByteBuffer.allocateDirect(modelInputSize);
          bf.order(ByteOrder.nativeOrder());
          int[] pixels = new int[bmWidth * bmHeight];
          bm.getPixels(pixels, 0, bm.getWidth(), 0, 0, bm.getWidth(), bm.getHeight());
          for(int i = 0; i < pixels.length; i ++){
              int r = (pixels[i] >> 16) & 0xFF;
              int g = (pixels[i] >> 8) & 0xFF;
              int b = pixels[i] & 0xFF;
              float normalizePixelValue = (r + g + b) / 3f / 255f;
              bf.putFloat(normalizePixelValue);
          }
          return bf;
      }
      void close(){
          interpreter.close();
      }
  }
}

0a2653c851af460fa595bd959398a8f1.png


DEMO


Demo源码下载


相关


TensorFlow Lite(Jinpeng)

TensorFlow Lite 指南

TensorFlow Lite 示例应用

Kotlin入门(4)声明与操作数组

Kotlin Array

google / tensorflow / tensorflow-lite


相关文章
|
3月前
|
XML API Android开发
码农之重学安卓:利用androidx.preference 快速创建一、二级设置菜单(demo)
本文介绍了如何使用androidx.preference库快速创建具有一级和二级菜单的Android设置界面的步骤和示例代码。
120 1
码农之重学安卓:利用androidx.preference 快速创建一、二级设置菜单(demo)
|
2月前
|
Java Maven 开发工具
第一个安卓项目 | 中国象棋demo学习
本文是作者关于其第一个安卓项目——中国象棋demo的学习记录,展示了demo的运行结果、爬坑记录以及参考资料,包括解决Android Studio和maven相关问题的方法。
第一个安卓项目 | 中国象棋demo学习
|
1月前
|
机器学习/深度学习 TensorFlow API
使用 TensorFlow 和 Keras 构建图像分类器
【10月更文挑战第2天】使用 TensorFlow 和 Keras 构建图像分类器
|
2月前
|
Java Android开发 C++
🚀Android NDK开发实战!Java与C++混合编程,打造极致性能体验!📊
在Android应用开发中,追求卓越性能是不变的主题。本文介绍如何利用Android NDK(Native Development Kit)结合Java与C++进行混合编程,提升应用性能。从环境搭建到JNI接口设计,再到实战示例,全面展示NDK的优势与应用技巧,助你打造高性能应用。通过具体案例,如计算斐波那契数列,详细讲解Java与C++的协作流程,帮助开发者掌握NDK开发精髓,实现高效计算与硬件交互。
135 1
|
3月前
|
存储 搜索推荐 Java
探索安卓开发中的自定义视图:打造个性化UI组件Java中的异常处理:从基础到高级
【8月更文挑战第29天】在安卓应用的海洋中,一个独特的用户界面(UI)能让应用脱颖而出。自定义视图是实现这一目标的强大工具。本文将通过一个简单的自定义计数器视图示例,展示如何从零开始创建一个具有独特风格和功能的安卓UI组件,并讨论在此过程中涉及的设计原则、性能优化和兼容性问题。准备好让你的应用与众不同了吗?让我们开始吧!
|
3月前
|
Java 调度 Android开发
Android经典实战之Kotlin的delay函数和Java中的Thread.sleep有什么不同?
本文介绍了 Kotlin 中的 `delay` 函数与 Java 中 `Thread.sleep` 方法的区别。两者均可暂停代码执行,但 `delay` 适用于协程,非阻塞且高效;`Thread.sleep` 则阻塞当前线程。理解这些差异有助于提高程序效率与可读性。
75 1
|
3月前
|
Java Android开发
解决Android编译报错:Unable to make field private final java.lang.String java.io.File.path accessible
解决Android编译报错:Unable to make field private final java.lang.String java.io.File.path accessible
481 1
|
3月前
|
Android开发
Cannot create android app from an archive...containing both DEX and Java-bytecode content
Cannot create android app from an archive...containing both DEX and Java-bytecode content
35 2
|
3月前
|
开发者 算法 虚拟化
惊爆!Uno Platform 调试与性能分析终极攻略,从工具运用到代码优化,带你攻克开发难题成就完美应用
【8月更文挑战第31天】在 Uno Platform 中,调试可通过 Visual Studio 设置断点和逐步执行代码实现,同时浏览器开发者工具有助于 Web 版本调试。性能分析则利用 Visual Studio 的性能分析器检查 CPU 和内存使用情况,还可通过记录时间戳进行简单分析。优化性能涉及代码逻辑优化、资源管理和用户界面简化,综合利用平台提供的工具和技术,确保应用高效稳定运行。
83 0
|
3月前
|
安全 Apache 数据安全/隐私保护
你的Wicket应用安全吗?揭秘在Apache Wicket中实现坚不可摧的安全认证策略
【8月更文挑战第31天】在当前的网络环境中,安全性是任何应用程序的关键考量。Apache Wicket 是一个强大的 Java Web 框架,提供了丰富的工具和组件,帮助开发者构建安全的 Web 应用程序。本文介绍了如何在 Wicket 中实现安全认证,
44 0