利用OpenCV的绘图功能与TensorFlow的模型来识别手写数字。
完整代码见GitHub:跳转按钮。
1.加载数据
# Read the training table; every column other than 'label' is a feature column.
data = pd.read_csv('train.csv')

# Features as float32 scaled by 1/255 (presumably 8-bit pixel intensities),
# labels taken unchanged from the 'label' column.
x = data.drop(columns='label').to_numpy(dtype=np.float32) / 255.0
y = data['label'].to_numpy()

# Endless, shuffled stream of (features, label) mini-batches of 128 samples.
data = tf.data.Dataset.from_tensor_slices((x, y))
data_loader = (
    data.repeat()
    .shuffle(5000)
    .batch(128)
    .prefetch(1)
)
2.构造模型
class network(tf.keras.Model):
    """Small convolutional classifier producing 10 class scores.

    Pipeline: Conv2D(32, 5x5, ReLU) -> MaxPool(2) -> Conv2D(64, 3x3, ReLU)
    -> MaxPool(2) -> Flatten -> Dense(1024) -> Dropout(0.5) -> Dense(10).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, kernel_size=5, activation=tf.nn.relu)
        self.maxpool1 = tf.keras.layers.MaxPool2D(2, strides=2)
        self.conv2 = tf.keras.layers.Conv2D(64, kernel_size=3, activation=tf.nn.relu)
        self.maxpool2 = tf.keras.layers.MaxPool2D(2, strides=2)
        self.flatten = tf.keras.layers.Flatten()
        # NOTE(review): fc1 has no activation, so fc1 -> out is a purely
        # linear stack apart from dropout. Left as-is so existing checkpoints
        # keep producing the same predictions.
        self.fc1 = tf.keras.layers.Dense(1024)
        self.dropout = tf.keras.layers.Dropout(rate=0.5)
        self.out = tf.keras.layers.Dense(10)

    def call(self, x, is_training=False):
        """Run the forward pass.

        Args:
            x: Input that can be reshaped to [-1, 28, 28, 1].
            is_training: When True, dropout is active and raw logits are
                returned (the loss applies softmax itself). When False,
                dropout is disabled and softmax probabilities are returned.

        Returns:
            A [batch, 10] tensor of logits (training) or probabilities.
        """
        x = tf.reshape(x, [-1, 28, 28, 1])
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        # Keras does not know about the custom `is_training` flag, so it must
        # be forwarded explicitly; otherwise dropout is never applied.
        x = self.dropout(x, training=is_training)
        x = self.out(x)
        if not is_training:
            x = tf.nn.softmax(x)
        return x


conv = network()
3.定义损失函数和精度函数
def cross_entropy_loss(x, y):
    """Return the mean sparse softmax cross-entropy of logits `x` against labels `y`.

    `y` is cast to int64 before being used as class indices.
    """
    labels = tf.cast(y, tf.int64)
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels,
        logits=x,
    )
    return tf.reduce_mean(per_example)


def accuracy(y_pred, y_true):
    """Return the fraction of rows whose argmax over axis 1 equals the label."""
    predicted = tf.argmax(y_pred, 1)
    expected = tf.cast(y_true, tf.int64)
    hits = tf.cast(tf.equal(predicted, expected), tf.float32)
    return tf.reduce_mean(hits)
4.构造优化函数以及训练代码
# Plain stochastic gradient descent with a learning rate of 0.01.
optimizer = tf.optimizers.SGD(0.01)


def run_optimizer(x, y):
    """Apply one gradient step of `optimizer` to `conv` for the batch (x, y)."""
    with tf.GradientTape() as tape:
        logits = conv(x, is_training=True)
        batch_loss = cross_entropy_loss(logits, y)

    weights = conv.trainable_variables
    grads = tape.gradient(batch_loss, weights)
    optimizer.apply_gradients(zip(grads, weights))


# Train for 1000 batches; every 50th step, print accuracy on the batch just used.
for i, (x_batch, y_batch) in enumerate(data_loader.take(1000), 1):
    run_optimizer(x_batch, y_batch)
    if i % 50 != 0:
        continue
    pred = conv(x_batch)
    acc = accuracy(pred, y_batch)
    print("%f" % (acc))
5.保存模型
# Write the current weights of `conv` to the checkpoint named 'mannul_checkpoint'.
conv.save_weights('mannul_checkpoint')
6.加载模型
# Restore the weights of `conv` from the checkpoint named 'mannul_checkpoint'.
conv.load_weights('mannul_checkpoint')
7.鼠标绘图
drawing = False  # True while the left button is held down
mode = True  # True: draw rectangle, False: draw circle (not read by mouse_event)
start = (-1, -1)  # position of the most recent left-button press


def mouse_event(event, x, y, flags, param):
    """Mouse callback that paints filled circles on `img` while the left button is down.

    Args:
        event: OpenCV mouse event code.
        x: Cursor x coordinate.
        y: Cursor y coordinate.
        flags: Unused.
        param: Unused.
    """
    global start, drawing

    # Left button pressed: begin a stroke and remember where it started.
    if event == cv2.EVENT_LBUTTONDOWN:
        drawing = True
        start = (x, y)
        return

    # Left button released: end the stroke.
    if event == cv2.EVENT_LBUTTONUP:
        drawing = False
        return

    # Cursor moved during a stroke: stamp a radius-8 dot with a random bright intensity.
    if event == cv2.EVENT_MOUSEMOVE and drawing:
        cv2.circle(img, (x, y), 8, (random.randint(185, 255)), -1)
8.将绘制的图片送入TensorFlow模型中进行识别,并在图像中显示类别
# Single-channel 512x512 drawing canvas, initially filled with gray level 100.
img = np.zeros((512, 512, 1), np.uint8)
img[:, :] = 100
cv2.namedWindow('image')
cv2.setMouseCallback('image', mouse_event)

while True:
    # Only the top-left 400x400 region of the canvas is fed to the model;
    # window 'a' previews the 28x28 down-scaled version of that region.
    temp = img.copy()
    a = temp[:400, :400]
    cv2.imshow('a', cv2.resize(a, (28, 28)))

    temp = temp.astype(np.float32)
    temp = cv2.resize(temp[:400, :400], (28, 28))
    # Scale to [0, 1] to match the 1/255 normalization applied to the
    # training data; feeding raw 0..255 values skews the prediction.
    temp = np.reshape(temp, [1, 784]) / 255.0
    pred = conv(temp, is_training=False)
    b = np.argmax(pred.numpy(), axis=1)

    cv2.imshow('image', img)

    # Poll the keyboard once per frame. Calling waitKey twice would let the
    # first call swallow a key press that the second branch is waiting for.
    key = cv2.waitKey(1) & 0xFF
    if key == ord('a'):
        # 'a': clear the canvas and print the predicted class below the
        # region that is fed to the model.
        img = np.zeros((512, 512, 1), np.uint8)
        cv2.putText(img, str(b), (250, 450), cv2.FONT_HERSHEY_COMPLEX, 2.0, (100, 200, 200), 5)
    elif key == 27:
        # ESC: leave the loop.
        break

cv2.destroyAllWindows()
公众号:FPGA之旅