Tensorflow MobileNet移植到Android
最近看到一个巨牛的人工智能教程,分享一下给大家。教程不仅是零基础,通俗易懂,而且非常风趣幽默,像看小说一样!觉得太牛了,所以分享给大家。平时碎片时间可以当小说看,【点这里可以去膜拜一下大神的“小说”】。
1 CKPT模型转换pb文件
使用上一篇博客《MobileNet V1官方预训练模型的使用》中下载的MobileNet V1官方预训练的模型《MobileNet_v1_1.0_192》。虽然打包下载的文件中包含已经转换过的pb文件,但是官方提供的pb模型输出是1001类别对应的概率,我们需要的是概率最大的3类。可在原始网络中使用函数tf.nn.top_k获取概率最大的3类,将函数tf.nn.top_k作为网络中的一个计算节点。模型转换代码如下所示。
import tensorflow as tf from mobilenet_v1 import mobilenet_v1,mobilenet_v1_arg_scope import numpy as np slim = tf.contrib.slim CKPT = 'mobilenet_v1_1.0_192.ckpt' def build_model(inputs): with slim.arg_scope(mobilenet_v1_arg_scope(is_training=False)): logits, end_points = mobilenet_v1(inputs, is_training=False, depth_multiplier=1.0, num_classes=1001) scores = end_points['Predictions'] print(scores) #取概率最大的3个类别及其对应概率 output = tf.nn.top_k(scores, k=3, sorted=True) #indices为类别索引,values为概率值 return output.indices,output.values def load_model(sess): loader = tf.train.Saver() loader.restore(sess,CKPT) inputs=tf.placeholder(dtype=tf.float32,shape=(1,192,192,3),name='input') classes_tf,scores_tf = build_model(inputs) classes = tf.identity(classes_tf, name='classes') scores = tf.identity(scores_tf, name='scores') with tf.Session() as sess: load_model(sess) graph = tf.get_default_graph() output_graph_def = tf.graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), [classes.op.name,scores.op.name]) tf.train.write_graph(output_graph_def, 'model', 'mobilenet_v1_1.0_192.pb', as_text=False)
上面代码中,单一的所有类别概率经过计算节点tf.nn.top_k后分为两个输出:概率最大的3个类别classes,概率最大的3个类别的概率scores。执行上面代码后,在目录“model”中得到文件mobilenet_v1_1.0_192.pb。
2 移植到Android中
2.1 AndroidStudio中使用Tensorflow Mobile
首先,AndroidStudio版本必须是3.0及以上。创建Android Project后,在Module:app的build.gradle文件中的dependencies中加入如下:
compile 'org.tensorflow:tensorflow-android:+'
2.2 Tensorflow Mobile接口
使用Tensorflow Mobile库中模型调用封装类org.tensorflow.contrib.android.TensorFlowInferenceInterface完成模型的调用,主要使用的如下函数。
public TensorFlowInferenceInterface(AssetManager assetManager, String model){...} public void feed(String inputName, float[] src, long... dims) {...} public void run(String[] outputNames) {...} public void fetch(String outputName, int[] dst) {...}
其中,构造函数中的参数model表示目录“assets”中模型名称。feed函数中参数inputName表示输入节点的名称,即对应模型转换时指定输入节点的名称“input”,参数src表示输入数据数组,变长参数dims表示输入的维度,如传入1,192,192,3则表示输入数据的Shape=[1,192,192,3]。函数run的参数outputNames表示执行从输入节点到outputNames中节点的所有路径。函数fetch中参数outputName表示输出节点的名称,将指定的输出节点的数据拷贝到dst中。
2.3 Bitmap对象转float[]
注意到,在2.1小节中函数feed传入到输入节点的数据对象是float[]。因此有必要将Bitmap转为float[]对象,示例代码如下所示。
//读取Bitmap像素值,并放入到浮点数数组中。归一化到[-1,1] private float[] getFloatImage(Bitmap bitmap){ Bitmap bm = getResizedBitmap(bitmap,inputWH,inputWH); bm.getPixels(inputIntData, 0, bm.getWidth(), 0, 0, bm.getWidth(), bm.getHeight()); for (int i = 0; i < inputIntData.length; ++i) { final int val = inputIntData[i]; inputFloatData[i * 3 + 0] =(float) (((val >> 16) & 0xFF)/255.0-0.5)*2; inputFloatData[i * 3 + 1] = (float)(((val >> 8) & 0xFF)/255.0-0.5)*2; inputFloatData[i * 3 + 2] = (float)(( val & 0xFF)/255.0-0.5)*2 ; } return inputFloatData; }
由于MobileNet V1预训练的模型输入数据归一化到[-1,1],因此在函数getFloatImage中转换数据的同时将数据归一化到[-1,1]。
2.4 封装模型调用
为了便于调用,将与模型相关的调用函数封装到类TFModelUtils中,通过TFModelUtils的run函数完成模型的调用,示例代码如下所示。
package com.huachao.mn_v1_192; import android.content.res.AssetManager; import android.graphics.Bitmap; import android.graphics.Matrix; import android.util.Log; import org.tensorflow.contrib.android.TensorFlowInferenceInterface; import java.io.BufferedReader; import java.io.InputStream; import java.io.InputStreamReader; import java.util.HashMap; import java.util.Map; public class TFModelUtils { private TensorFlowInferenceInterface inferenceInterface; private int[] inputIntData ; private float[] inputFloatData ; private int inputWH; private String inputName; private String[] outputNames; private Map<Integer,String> dict; public TFModelUtils(AssetManager assetMngr,int inputWH,String inputName,String[]outputNames,String modelName){ this.inputWH=inputWH; this.inputName=inputName; this.outputNames=outputNames; this.inputIntData=new int[inputWH*inputWH]; this.inputFloatData = new float[inputWH*inputWH*3]; //从assets目录加载模型 inferenceInterface= new TensorFlowInferenceInterface(assetMngr, modelName); this.loadLabel(assetMngr); } public Map<String,Object> run(Bitmap bitmap){ float[] inputData = getFloatImage(bitmap); //将输入数据复制到TensorFlow中,指定输入Shape=[1,INPUT_WH,INPUT_WH,3] inferenceInterface.feed(inputName, inputData, 1, inputWH, inputWH, 3); // 执行模型 inferenceInterface.run( outputNames ); //将输出Tensor对象复制到指定数组中 int[] classes=new int[3]; float[] scores=new float[3]; inferenceInterface.fetch(outputNames[0], classes); inferenceInterface.fetch(outputNames[1], scores); Map<String,Object> results=new HashMap<>(); results.put("scores",scores); String[] classesLabel = new String[3]; for(int i =0;i<3;i++){ int idx=classes[i]; classesLabel[i]=dict.get(idx); // System.out.printf("classes:"+dict.get(idx)+",scores:"+scores[i]+"\n"); } results.put("classes",classesLabel); return results; } //读取Bitmap像素值,并放入到浮点数数组中。归一化到[-1,1] private float[] getFloatImage(Bitmap bitmap){ Bitmap bm = getResizedBitmap(bitmap,inputWH,inputWH); bm.getPixels(inputIntData, 0, bm.getWidth(), 0, 0, bm.getWidth(), bm.getHeight()); for (int i = 0; i < inputIntData.length; ++i) { final int val = inputIntData[i]; inputFloatData[i * 3 + 0] =(float) (((val >> 16) & 0xFF)/255.0-0.5)*2; inputFloatData[i * 3 + 1] = (float)(((val >> 8) & 0xFF)/255.0-0.5)*2; inputFloatData[i * 3 + 2] = (float)(( val & 0xFF)/255.0-0.5)*2 ; } return inputFloatData; } //对图像做Resize public Bitmap getResizedBitmap(Bitmap bm, int newWidth, int newHeight) { int width = bm.getWidth(); int height = bm.getHeight(); float scaleWidth = ((float) newWidth) / width; float scaleHeight = ((float) newHeight) / height; Matrix matrix = new Matrix(); matrix.postScale(scaleWidth, scaleHeight); Bitmap resizedBitmap = Bitmap.createBitmap( bm, 0, 0, width, height, matrix, false); bm.recycle(); return resizedBitmap; } private void loadLabel( AssetManager assetManager ) { dict=new HashMap<>(); try { InputStream stream = assetManager.open("label.txt"); InputStreamReader isr=new InputStreamReader(stream); BufferedReader br=new BufferedReader(isr); String line; while((line=br.readLine())!=null){ line=line.trim(); String[] arr = line.split(","); if(arr.length!=2) continue; int key=Integer.parseInt(arr[0]); String value = arr[1]; dict.put(key,value); } }catch (Exception e){ e.printStackTrace(); Log.e("ERROR",e.getMessage()); } } }
3 模型测试