我使用黑色素瘤癌症数据集对 CNN（MobileNet）进行了微调，得到一个区分 Melanoma（黑色素瘤）与 Non-Melanoma（非黑色素瘤）两个类别的分类模型，然后将其转换为 TensorFlow Lite 模型。但当我把它部署到手机上运行，并点击“Classify（分类）”按钮时，程序崩溃并提示“很抱歉，InceptionTutorial 已停止运行”。我的 Python 微调代码如下（参考：https://www.kaggle.com/gabrielmv/melanoma-classifie-mobilenet/notebook）：
from keras.models import Model
from keras.layers import Dense, Dropout
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping
from keras.applications.mobilenet import MobileNet, preprocess_input
import numpy as np
import itertools
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw a confusion matrix as an annotated heat map on the current figure.

    Args:
        cm: 2-D confusion-matrix array; rows are true labels and columns are
            predicted labels (matching the axis titles set below).
        classes: tick labels, one per row/column of ``cm``.
        normalize: if True, divide every row by its total so each cell shows
            the fraction of that true class, formatted with two decimals.
        title: figure title.
        cmap: matplotlib colormap used for the heat map.
    """
    if not normalize:
        print('Confusion matrix, without normalization')
    else:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        cm = cm.astype('float') / row_totals
        print("Normalized confusion matrix")
    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    positions = np.arange(len(classes))
    plt.xticks(positions, classes, rotation=45)
    plt.yticks(positions, classes)

    cell_format = '.2f' if normalize else 'd'
    # Cells darker than half the maximum get white text so the numbers stay legible.
    midpoint = cm.max() / 2.
    for row, col in np.ndindex(cm.shape[0], cm.shape[1]):
        cell = cm[row, col]
        plt.text(col, row, format(cell, cell_format),
                 horizontalalignment="center",
                 color="white" if cell > midpoint else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
def plot_training_curves(history):
    """Plot training/validation loss and accuracy curves from a Keras History.

    Draws two figures (loss, then accuracy) and shows them.

    Older Keras versions log the accuracy metric as ``'acc'``/``'val_acc'``,
    newer ones as ``'accuracy'``/``'val_accuracy'``; both spellings are
    accepted so the function does not raise KeyError on either.

    Args:
        history: object returned by ``model.fit``/``fit_generator``; only its
            ``history`` dict is read.
    """
    logs = history.history
    acc_key = 'acc' if 'acc' in logs else 'accuracy'
    acc = logs[acc_key]
    val_acc = logs['val_' + acc_key]
    loss = logs['loss']
    val_loss = logs['val_loss']
    epochs = range(1, len(acc) + 1)

    # Start a fresh figure so the loss curves are not drawn onto whatever
    # figure (e.g. a confusion matrix) happens to be current.
    plt.figure()
    plt.plot(epochs, loss, 'r', label='Training loss')
    plt.plot(epochs, val_loss, 'g', label='Validation loss')
    plt.title('Training and validation loss')
    plt.legend()

    plt.figure()
    plt.plot(epochs, acc, 'r', label='Training acc')
    plt.plot(epochs, val_acc, 'g', label='Validation acc')
    plt.title('Training and validation accuracy')
    plt.legend()

    # No trailing plt.figure() here: it only produced an extra empty window.
    plt.show()
def print_results(cm):
    """Print accuracy, sensitivity and specificity (as percentages).

    Class index 0 is the positive class. Row 0 of ``cm`` holds
    (true positives, false negatives) and row 1 holds
    (false positives, true negatives).

    Args:
        cm: confusion matrix indexable as ``cm[row, col]`` (e.g. a numpy
            array from ``sklearn.metrics.confusion_matrix``).
    """
    (tp, fn), (fp, tn) = cm[0, :2], cm[1, :2]
    total = tp + tn + fp + fn
    metrics = (
        ('Accuracy: ', ((tp + tn) / total) * 100),
        ('Sensitivity: ', (tp / (tp + fn)) * 100),
        ('Specificity: ', (tn / (tn + fp)) * 100),
    )
    for label, value in metrics:
        print(label, value)
def fine_tune_mobile_net(train_batches, train_steps, class_weights, valid_batches, val_steps, file_path):
    """Build a two-class MobileNet classifier, train its new head, then fine-tune it.

    Phase 1 freezes every pretrained MobileNet layer and trains only the new
    Dropout + Dense(2, softmax) head (up to 50 epochs). The best checkpoint is
    then reloaded from ``file_path``. Phase 2 unfreezes the last 23 layers of
    the model and trains again (up to 75 epochs).

    Both phases use the module-level ``callbacks`` list. Its ModelCheckpoint
    is expected to write to ``file_path``.

    Args:
        train_batches: training data generator (e.g. from ``flow_from_directory``).
        train_steps: number of batches per training epoch.
        class_weights: dict mapping class index -> loss weight.
        valid_batches: validation data generator.
        val_steps: number of batches per validation pass.
        file_path: checkpoint file reloaded between the two phases.

    Returns:
        ``(model, history)``; ``history`` is the History of phase 2.
    """
    mobile = MobileNet()  # mobile.summary() lists the layers
    # Drop MobileNet's own ImageNet classifier by branching off six layers from
    # the end. NOTE(review): presumably the global-average-pooling output —
    # confirm with mobile.summary() for the installed Keras version.
    x = mobile.layers[-6].output
    x = Dropout(0.5)(x)
    predictions = Dense(2, activation='softmax')(x)
    model = Model(inputs=mobile.input, outputs=predictions)

    # Phase 1: train only the new head.
    for layer in mobile.layers:
        layer.trainable = False
    model.compile(Adam(lr=0.003),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    model.fit_generator(train_batches,
                        steps_per_epoch=train_steps,
                        class_weight=class_weights,
                        validation_data=valid_batches,
                        validation_steps=val_steps,
                        epochs=50,
                        verbose=1,
                        callbacks=callbacks)
    # Continue from the best phase-1 weights, not the last epoch's.
    model.load_weights(file_path)

    print('*** Fine Tunning MobileNet ***')
    # Phase 2: unfreeze the top 23 layers. Phase 1 froze every MobileNet layer,
    # so the top layers must be re-enabled explicitly. Freezing only
    # layers[:-23] (as before) left them frozen, and this phase just retrained
    # the head a second time.
    for layer in model.layers[:-23]:
        layer.trainable = False
    for layer in model.layers[-23:]:
        layer.trainable = True
    # Recompile so the changed trainable flags take effect.
    # NOTE(review): fine-tuning usually uses a smaller learning rate than 0.003.
    model.compile(Adam(lr=0.003),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    history = model.fit_generator(train_batches,
                                  steps_per_epoch=train_steps,
                                  class_weight=class_weights,
                                  validation_data=valid_batches,
                                  validation_steps=val_steps,
                                  epochs=75,
                                  verbose=1,
                                  callbacks=callbacks)
    return model, history
def save_model(model, file_path):
    """Write the model's architecture as JSON to ``model.json`` in the working directory.

    Only the architecture is written; weights are not saved here.
    ``file_path`` is accepted for call-site compatibility but is not used.

    Args:
        model: object providing ``to_json()`` (a Keras model).
        file_path: unused.
    """
    serialized = model.to_json()
    with open('model.json', mode='w') as out:
        out.write(serialized)
# Checkpoint file: ModelCheckpoint writes the best weights here, and
# fine_tune_mobile_net / the evaluation below reload them from it.
file_path = 'weights-mobilenet-2.0.h5'
# Module-level on purpose: fine_tune_mobile_net() reads this global list.
# NOTE(review): Keras >= 2.3 logs the metric as 'val_accuracy'; on such
# versions monitor='val_acc' is never found and no checkpoint is written.
callbacks = [
    ModelCheckpoint(file_path, monitor='val_acc', verbose=1, save_best_only=True, mode='max'),
    ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=8, verbose=1, mode='min', min_lr=0.00001),
    EarlyStopping(monitor='val_loss', min_delta=1e-10, patience=15, verbose=1),
]
training_path = '../input/dermmel/DermMel/train_sep'
validation_path = '../input/dermmel/DermMel/valid'
test_path = '../input/dermmel/DermMel/test'
num_train_samples = 10682
num_val_samples = 3562
num_test_samples = 3561
train_batch_size = 16
val_batch_size = 16
test_batch_size = 16
# Whole batches per pass, rounded up so every sample is seen. Keras expects
# integer step counts; np.ceil alone returns a float.
train_steps = int(np.ceil(num_train_samples / train_batch_size))
val_steps = int(np.ceil(num_val_samples / val_batch_size))
# Previously computed from the validation counts by mistake.
test_steps = int(np.ceil(num_test_samples / test_batch_size))
# Loss weights compensate for class imbalance. NOTE(review): assumes the
# melanoma sub-directory sorts first, so flow_from_directory gives it index 0.
class_weights = {
    0: 4.1,  # melanoma
    1: 1.0,  # non-melanoma
}
# All generators scale pixels to [0, 1] (rescale=1/255). Any inference code
# (e.g. the Android app) must feed the exported model the same input range.
train_batches = ImageDataGenerator(rescale=1. / 255).flow_from_directory(
    training_path,
    target_size=(224, 224),
    batch_size=train_batch_size,  # was val_batch_size
    class_mode='categorical')
valid_batches = ImageDataGenerator(rescale=1. / 255).flow_from_directory(
    validation_path,
    target_size=(224, 224),
    batch_size=val_batch_size,
    class_mode='categorical')
# shuffle=False keeps prediction order aligned with test_batches.classes.
test_batches = ImageDataGenerator(rescale=1. / 255).flow_from_directory(
    test_path,
    target_size=(224, 224),
    batch_size=test_batch_size,
    class_mode='categorical',
    shuffle=False)

model, history = fine_tune_mobile_net(train_batches, train_steps, class_weights, valid_batches, val_steps, file_path)
save_model(model, file_path)
# Evaluate the best checkpoint rather than the last-epoch weights.
model.load_weights(file_path)
test_labels = test_batches.classes
# Use the test-set step count (was val_steps) so exactly one pass over the
# test set is predicted.
predictions = model.predict_generator(test_batches, steps=test_steps, verbose=1)
cm = confusion_matrix(test_labels, predictions.argmax(axis=1))
然后，我基于之前用于多类别物体分类的 Android 源代码（https://github.com/soum-io/TensorFlowLiteInceptionTutorial）定制了自己的黑色素瘤分类应用。该应用包含两个主要的 Activity：
ChooseModel.java
package com.soumio.inceptiontutorial;
import android.Manifest;
import android.content.ContentValues;
import android.content.Intent;
import android.content.pm.ActivityInfo;
import android.content.pm.PackageManager;
import android.net.Uri;
import android.os.Build;
import android.provider.MediaStore;
import android.support.annotation.NonNull;
import android.support.v4.app.ActivityCompat;
import android.support.v4.content.ContextCompat;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.Toast;
import com.soundcloud.android.crop.Crop;
import java.io.File;
public class ChooseModel extends AppCompatActivity {
// button for each available classifier
private Button inceptionFloat;
private Button inceptionQuant;
// for permission requests
public static final int REQUEST_PERMISSION = 300;
// request code for permission requests to the os for image
public static final int REQUEST_IMAGE = 100;
// will hold uri of image obtained from camera
private Uri imageUri;
// string to send to next activity that describes the chosen classifier
private String chosen;
//boolean value dictating if chosen model is quantized version or not.
private boolean quant;
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_choose_model);
// request permission to use the camera on the user's phone
if (ActivityCompat.checkSelfPermission(this.getApplicationContext(), android.Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED){
ActivityCompat.requestPermissions(this, new String[] {android.Manifest.permission.CAMERA}, REQUEST_PERMISSION);
}
// request permission to write data (aka images) to the user's external storage of their phone
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M
&& ContextCompat.checkSelfPermission(this, Manifest.permission.WRITE_EXTERNAL_STORAGE) != PackageManager.PERMISSION_GRANTED) {
ActivityCompat.requestPermissions(this, new String[]{Manifest.permission.WRITE_EXTERNAL_STORAGE},
REQUEST_PERMISSION);
}
// request permission to read data (aka images) from the user's external storage of their phone
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M
&& ContextCompat.checkSelfPermission(this, Manifest.permission.READ_EXTERNAL_STORAGE) != PackageManager.PERMISSION_GRANTED) {
ActivityCompat.requestPermissions(this, new String[]{Manifest.permission.READ_EXTERNAL_STORAGE},
REQUEST_PERMISSION);
}
// on click for inception float model
inceptionFloat = (Button)findViewById(R.id.inception_float);
inceptionFloat.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View view) {
// filename in assets
chosen = "inception_float.tflite";
// model in not quantized
quant = false;
// open camera
openCameraIntent();
}
});
// on click for inception quant model
inceptionQuant = (Button)findViewById(R.id.inception_quant);
inceptionQuant.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View view) {
// filename in assets
chosen = "inception_quant.tflite";
// model in not quantized
quant = true;
// open camera
openCameraIntent();
}
});
}
// opens camera for user
private void openCameraIntent(){
ContentValues values = new ContentValues();
values.put(MediaStore.Images.Media.TITLE, "New Picture");
values.put(MediaStore.Images.Media.DESCRIPTION, "From your Camera");
// tell camera where to store the resulting picture
imageUri = getContentResolver().insert(
MediaStore.Images.Media.EXTERNAL_CONTENT_URI, values);
Intent intent = new Intent(MediaStore.ACTION_IMAGE_CAPTURE);
intent.putExtra(MediaStore.EXTRA_OUTPUT, imageUri);
setRequestedOrientation(ActivityInfo.SCREEN_ORIENTATION_PORTRAIT);
// start camera, and wait for it to finish
startActivityForResult(intent, REQUEST_IMAGE);
}
// checks that the user has allowed all the required permission of read and write and camera. If not, notify the user and close the application
@Override
public void onRequestPermissionsResult(final int requestCode, @NonNull final String[] permissions, @NonNull final int[] grantResults) {
super.onRequestPermissionsResult(requestCode, permissions, grantResults);
if (requestCode == REQUEST_PERMISSION) {
if (!(grantResults.length > 0 && grantResults[0] == PackageManager.PERMISSION_GRANTED)) {
Toast.makeText(getApplicationContext(),"This application needs read, write, and camera permissions to run. Application now closing.",Toast.LENGTH_LONG);
System.exit(0);
}
}
}
// dictates what to do after the user takes an image, selects and image, or crops an image
@Override
protected void onActivityResult(int requestCode, int resultCode, Intent data){
super.onActivityResult(requestCode, resultCode, data);
// if the camera activity is finished, obtained the uri, crop it to make it square, and send it to 'Classify' activity
if(requestCode == REQUEST_IMAGE && resultCode == RESULT_OK) {
try {
Uri source_uri = imageUri;
Uri dest_uri = Uri.fromFile(new File(getCacheDir(), "cropped"));
// need to crop it to square image as CNN's always required square input
Crop.of(source_uri, dest_uri).asSquare().start(ChooseModel.this);
} catch (Exception e) {
e.printStackTrace();
}
}
// if cropping acitivty is finished, get the resulting cropped image uri and send it to 'Classify' activity
else if(requestCode == Crop.REQUEST_CROP && resultCode == RESULT_OK){
imageUri = Crop.getOutput(data);
Intent i = new Intent(ChooseModel.this, Classify.class);
// put image data in extras to send
i.putExtra("resID_uri", imageUri);
// put filename in extras
i.putExtra("chosen", chosen);
// put model type in extras
i.putExtra("quant", quant);
// send other required data
startActivity(i);
}
}
}
Classify.java
package com.soumio.inceptiontutorial;
import android.content.Intent;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.graphics.Matrix;
import android.graphics.drawable.BitmapDrawable;
import android.net.Uri;
import android.os.SystemClock;
import android.provider.MediaStore;
import android.support.v7.app.AppCompatActivity;
import android.os.Bundle;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.TextView;
import org.tensorflow.lite.Interpreter;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
public class Classify extends AppCompatActivity {
// presets for rgb conversion
private static final int RESULTS_TO_SHOW = 3;
private static final int IMAGE_MEAN = 128;
private static final float IMAGE_STD = 128.0f;
// options for model interpreter
private final Interpreter.Options tfliteOptions = new Interpreter.Options();
// tflite graph
private Interpreter tflite;
// holds all the possible labels for model
private List<String> labelList;
// holds the selected image data as bytes
private ByteBuffer imgData = null;
// holds the probabilities of each label for non-quantized graphs
private float[][] labelProbArray = null;
// holds the probabilities of each label for quantized graphs
private byte[][] labelProbArrayB = null;
// array that holds the labels with the highest probabilities
private String[] topLables = null;
// array that holds the highest probabilities
private String[] topConfidence = null;
// selected classifier information received from extras
private String chosen;
private boolean quant;
// input image dimensions for the Inception Model
private int DIM_IMG_SIZE_X = 244;
private int DIM_IMG_SIZE_Y = 244;
private int DIM_PIXEL_SIZE = 3;
// int array to hold image data
private int[] intValues;
// activity elements
private ImageView selected_image;
private Button classify_button;
private Button back_button;
private TextView label1;
private TextView label2;
private TextView label3;
private TextView Confidence1;
private TextView Confidence2;
private TextView Confidence3;
// priority queue that will hold the top results from the CNN
private PriorityQueue<Map.Entry<String, Float>> sortedLabels =
new PriorityQueue<>(
RESULTS_TO_SHOW,
new Comparator<Map.Entry<String, Float>>() {
@Override
public int compare(Map.Entry<String, Float> o1, Map.Entry<String, Float> o2) {
return (o1.getValue()).compareTo(o2.getValue());
}
});
@Override
protected void onCreate(Bundle savedInstanceState) {
// get all selected classifier data from classifiers
chosen = (String) getIntent().getStringExtra("chosen");
quant = (boolean) getIntent().getBooleanExtra("quant", false);
// initialize array that holds image data
intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y];
super.onCreate(savedInstanceState);
//initilize graph and labels
try{
tflite = new Interpreter(loadModelFile(), tfliteOptions);
labelList = loadLabelList();
} catch (Exception ex){
ex.printStackTrace();
}
// initialize byte array. The size depends if the input data needs to be quantized or not
if(quant){
imgData =
ByteBuffer.allocateDirect(
DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
} else {
imgData =
ByteBuffer.allocateDirect(
4 * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE);
}
imgData.order(ByteOrder.nativeOrder());
// initialize probabilities array. The datatypes that array holds depends if the input data needs to be quantized or not
if(quant){
labelProbArrayB= new byte[1][labelList.size()];
} else {
labelProbArray = new float[1][labelList.size()];
}
setContentView(R.layout.activity_classify);
// labels that hold top three results of CNN
label1 = (TextView) findViewById(R.id.label1);
label2 = (TextView) findViewById(R.id.label2);
label3 = (TextView) findViewById(R.id.label3);
// displays the probabilities of top labels
Confidence1 = (TextView) findViewById(R.id.Confidence1);
Confidence2 = (TextView) findViewById(R.id.Confidence2);
Confidence3 = (TextView) findViewById(R.id.Confidence3);
// initialize imageView that displays selected image to the user
selected_image = (ImageView) findViewById(R.id.selected_image);
// initialize array to hold top labels
topLables = new String[RESULTS_TO_SHOW];
// initialize array to hold top probabilities
topConfidence = new String[RESULTS_TO_SHOW];
// allows user to go back to activity to select a different image
back_button = (Button)findViewById(R.id.back_button);
back_button.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View view) {
Intent i = new Intent(Classify.this, ChooseModel.class);
startActivity(i);
}
});
// classify current dispalyed image
classify_button = (Button)findViewById(R.id.classify_image);
classify_button.setOnClickListener(new View.OnClickListener() {
@Override
public void onClick(View view) {
// get current bitmap from imageView
Bitmap bitmap_orig = ((BitmapDrawable)selected_image.getDrawable()).getBitmap();
// resize the bitmap to the required input size to the CNN
Bitmap bitmap = getResizedBitmap(bitmap_orig, DIM_IMG_SIZE_X, DIM_IMG_SIZE_Y);
// convert bitmap to byte array
convertBitmapToByteBuffer(bitmap);
// pass byte data to the graph
if(quant){
tflite.run(imgData, labelProbArrayB);
} else {
tflite.run(imgData, labelProbArray);
}
// display the results
printTopKLabels();
}
});
// get image from previous activity to show in the imageView
Uri uri = (Uri)getIntent().getParcelableExtra("resID_uri");
try {
Bitmap bitmap = MediaStore.Images.Media.getBitmap(getContentResolver(), uri);
selected_image.setImageBitmap(bitmap);
// not sure why this happens, but without this the image appears on its side
selected_image.setRotation(selected_image.getRotation() + 90);
} catch (IOException e) {
e.printStackTrace();
}
}
// loads tflite grapg from file
private MappedByteBuffer loadModelFile() throws IOException {
AssetFileDescriptor fileDescriptor = this.getAssets().openFd(chosen);
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
// converts bitmap to byte array which is passed in the tflite graph
private void convertBitmapToByteBuffer(Bitmap bitmap) {
if (imgData == null) {
return;
}
imgData.rewind();
bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
// loop through all pixels
int pixel = 0;
for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
final int val = intValues[pixel++];
// get rgb values from intValues where each int holds the rgb values for a pixel.
// if quantized, convert each rgb value to a byte, otherwise to a float
if(quant){
imgData.put((byte) ((val >> 16) & 0xFF));
imgData.put((byte) ((val >> 8) & 0xFF));
imgData.put((byte) (val & 0xFF));
} else {
imgData.putFloat((((val >> 16) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
imgData.putFloat((((val >> 8) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
imgData.putFloat((((val) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
}
}
}
}
// loads the labels from the label txt file in assets into a string array
private List<String> loadLabelList() throws IOException {
List<String> labelList = new ArrayList<String>();
BufferedReader reader =
new BufferedReader(new InputStreamReader(this.getAssets().open("labels.txt")));
String line;
while ((line = reader.readLine()) != null) {
labelList.add(line);
}
reader.close();
return labelList;
}
// print the top labels and respective confidences
private void printTopKLabels() {
// add all results to priority queue
for (int i = 0; i < labelList.size(); ++i) {
if(quant){
sortedLabels.add(
new AbstractMap.SimpleEntry<>(labelList.get(i), (labelProbArrayB[0][i] & 0xff) / 255.0f));
} else {
sortedLabels.add(
new AbstractMap.SimpleEntry<>(labelList.get(i), labelProbArray[0][i]));
}
if (sortedLabels.size() > RESULTS_TO_SHOW) {
sortedLabels.poll();
}
}
// get top results from priority queue
final int size = sortedLabels.size();
for (int i = 0; i < size; ++i) {
Map.Entry<String, Float> label = sortedLabels.poll();
topLables[i] = label.getKey();
topConfidence[i] = String.format("%.0f%%",label.getValue()*100);
}
// set the corresponding textviews with the results
label1.setText("1. "+topLables[2]);
label2.setText("2. "+topLables[1]);
// label3.setText("3. "+topLables[0]);
Confidence1.setText(topConfidence[2]);
Confidence2.setText(topConfidence[1]);
// Confidence3.setText(topConfidence[0]);
}
// resizes bitmap to given dimensions
public Bitmap getResizedBitmap(Bitmap bm, int newWidth, int newHeight) {
int width = bm.getWidth();
int height = bm.getHeight();
float scaleWidth = ((float) newWidth) / width;
float scaleHeight = ((float) newHeight) / height;
Matrix matrix = new Matrix();
matrix.postScale(scaleWidth, scaleHeight);
Bitmap resizedBitmap = Bitmap.createBitmap(
bm, 0, 0, width, height, matrix, false);
return resizedBitmap;
}
}
有谁能帮我解决这个问题吗？谢谢！问题来源：StackOverflow，地址：/questions/59465930/android-classification-app-crashes-with-tensorflow-lite-model
版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。