# 在 mnist_predict 目录下新建文件，命名为 predict_pic.py，用于识别图像中的手写数字。
import os

# Set TensorFlow's C++ log level BEFORE importing tensorflow: the level is
# picked up by the native logging layer, so assigning it after the import does
# not suppress the INFO/WARNING messages emitted while TensorFlow loads.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import cv2  # noqa: E402
import numpy as np  # noqa: E402
import tensorflow as tf  # noqa: E402
# 将输入的彩色图像转换为二值化图像
def color_input(endimg, thresh=127):
    """Convert an OpenCV image into an inverted binary (0/255) image.

    Args:
        endimg: Image array as returned by ``cv2.imread``. BGR (3-channel)
            input is the original use case. Already-grayscale (2-D) and BGRA
            (4-channel) images are also accepted.
        thresh: Gray level used as the binarization threshold (default 127,
            the value used originally).

    Returns:
        Single-channel image where pixels with gray value <= ``thresh`` become
        255 and brighter pixels become 0.
    """
    if endimg.ndim == 2:
        # Already single-channel; cvtColor(BGR2GRAY) would raise on this.
        img_gray = endimg
    elif endimg.shape[2] == 4:
        img_gray = cv2.cvtColor(endimg, cv2.COLOR_BGRA2GRAY)
    else:
        img_gray = cv2.cvtColor(endimg, cv2.COLOR_BGR2GRAY)
    # THRESH_BINARY_INV turns dark strokes on a light background into white
    # strokes on black -- presumably to match the MNIST training images.
    _, img_threshold = cv2.threshold(img_gray, thresh, 255, cv2.THRESH_BINARY_INV)
    return img_threshold
# 读取图像，显示原图及其二值化结果，并返回二值化图像
def read_pic(path, show=True):
    """Read an image file, binarize it with ``color_input`` and optionally show it.

    Args:
        path: Path of the image file to load.
        show: When true (the default, matching the original behavior), display
            the original image and then the binarized image in OpenCV windows,
            each waiting for a key press. The windows are closed afterwards.

    Returns:
        The binarized image produced by ``color_input``.

    Raises:
        FileNotFoundError: If OpenCV cannot read an image from ``path``.
    """
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    # cv2.imread reports failure by returning None rather than raising; without
    # this check the error surfaces later as an obscure OpenCV assertion.
    if img is None:
        raise FileNotFoundError("cannot read image: {}".format(path))
    img_threshold = color_input(img)
    if show:
        cv2.imshow('img', img)
        cv2.waitKey(0)
        cv2.imshow('img_threshold', img_threshold)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    return img_threshold
if __name__ == '__main__':
    import sys

    # Optional first command-line argument; defaults to the original file name.
    image_path = sys.argv[1] if len(sys.argv) > 1 else "nine.png"

    with tf.Session() as sess:
        # Rebuild the saved graph structure, then restore the trained weights.
        saver = tf.train.import_meta_graph('model_data/model.meta')
        saver.restore(sess, 'model_data/model')
        # Look up the network's input and output tensors by their graph names.
        input_x = sess.graph.get_tensor_by_name("Mul:0")
        y_conv2 = sess.graph.get_tensor_by_name("final_result:0")
        # Read and binarize the image.
        img_threshold = read_pic(image_path)
        # Scale to the 28x28 input size and flatten to one 784-value row.
        im = cv2.resize(img_threshold, (28, 28), interpolation=cv2.INTER_CUBIC)
        # NOTE(review): pixels are fed in the 0-255 range; confirm this matches
        # the preprocessing used when the model was trained.
        x_img = np.reshape(im, [-1, 784])
        # Run inference; the predicted digit is the index of the largest output.
        output = sess.run(y_conv2, feed_dict={input_x: x_img})
        result = np.argmax(output)
        print("识别结果为:{}".format(result))