TensorFlow结合OpenCV实现手写数字识别

简介: 笔记

利用OpenCV的绘图功能与TensorFlow的模型来识别手写数字。


完整代码GitHub上:跳转按钮.

1.加载数据

data = pd.read_csv('train.csv')
x = data.loc[:,data.columns != 'label'].values.astype(np.float32)
y = data['label'].values
x = x / 255.0
data = tf.data.Dataset.from_tensor_slices((x,y))
data_loader = data.repeat().shuffle(5000).batch(128).prefetch(1)

2.构造模型

class network(tf.keras.Model):
    def __init__(self):
        super(network, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(32,kernel_size=5,activation=tf.nn.relu)
        self.maxpool1 =tf.keras.layers.MaxPool2D(2,strides=2)
        self.conv2 = tf.keras.layers.Conv2D(64,kernel_size=3,activation=tf.nn.relu)
        self.maxpool2 = tf.keras.layers.MaxPool2D(2,strides=2)
        self.flatten = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(1024)
        self.dropout = tf.keras.layers.Dropout(rate=0.5)
        self.out = tf.keras.layers.Dense(10)
    def call(self,x,is_training=False):
        x = tf.reshape(x,[-1,28,28,1])
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.out(x)
        if not is_training:
            x = tf.nn.softmax(x)
        return x
conv = network()

3.定义损失函数和精度函数

def cross_entropy_loss(x,y):
    y = tf.cast(y,tf.int64)
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,logits=x)
    return tf.reduce_mean(loss)
def accuracy(y_pred,y_true):
    correct_pred = tf.equal(tf.argmax(y_pred,1),tf.cast(y_true,tf.int64))
    return tf.reduce_mean(tf.cast(correct_pred,tf.float32))

4.构造优化函数以及训练代码

optimizer = tf.optimizers.SGD(0.01)
def run_optimizer(x,y):
    with tf.GradientTape() as g:
        pred = conv(x,is_training=True)
        loss = cross_entropy_loss(pred,y)
    training_variable = conv.trainable_variables
    gradient = g.gradient(loss,training_variable)
    optimizer.apply_gradients(zip(gradient,training_variable))
for i,(x_batch,y_batch) in enumerate(data_loader.take(1000),1):
    run_optimizer(x_batch,y_batch)
    if i%50==0:
        pred = conv(x_batch)
        acc = accuracy(pred,y_batch)
        print("%f"%(acc))

5.保存模型

conv.save_weights('mannul_checkpoint')


6.加载模型

conv.load_weights('mannul_checkpoint')


7.鼠标绘图

drawing = False  # 是否开始画图
mode = True  # True:画矩形,False:画圆
start = (-1, -1)
def mouse_event(event, x, y, flags, param):
    global start, drawing, mode
    # 左键按下:开始画图
    if event == cv2.EVENT_LBUTTONDOWN:
        drawing = True
        start = (x, y)
    # 鼠标移动,画图
    elif event == cv2.EVENT_MOUSEMOVE:
        if drawing:
            cv2.circle(img, (x, y), 8, (random.randint(185,255)), -1)
    # 左键释放:结束画图
    elif event == cv2.EVENT_LBUTTONUP:
        drawing = False

8.将绘制的图片送入TensorFlow模型中进行识别,并在图像中显示类别

img = np.zeros((512, 512, 1), np.uint8)
img[:,:] = 100
cv2.namedWindow('image')
cv2.setMouseCallback('image', mouse_event)
while(True):
    temp = img.copy()
    a = temp[:400,:400]
    cv2.imshow('a',cv2.resize(a,(28,28)))
    temp = temp.astype(np.float32)
    temp = cv2.resize(temp[:400,:400],(28,28))
    temp = np.reshape(temp,[1,784])
    pred = conv(temp, is_training=False)
    b = np.argmax(pred.numpy(),axis=1)
    cv2.imshow('image', img)
    # 按下m切换模式
    if cv2.waitKey(1) == ord('a'):
        img = np.zeros((512, 512, 1), np.uint8)
        cv2.putText(img, str(b), (250, 450), cv2.FONT_HERSHEY_COMPLEX, 2.0, (100, 200, 200), 5)
    elif cv2.waitKey(1) == 27:
        break

12.png

公众号:FPGA之旅



目录
相关文章
|
6月前
|
机器学习/深度学习 算法 算法框架/工具
深度学习实战:基于TensorFlow与OpenCV的手语识别系统
深度学习实战:基于TensorFlow与OpenCV的手语识别系统
425 0
|
6月前
|
机器学习/深度学习 算法 TensorFlow
【Python深度学习】Tensorflow对半环形数据分类、手写数字识别、猫狗识别实战(附源码)
【Python深度学习】Tensorflow对半环形数据分类、手写数字识别、猫狗识别实战(附源码)
124 0
|
1月前
|
并行计算 PyTorch TensorFlow
Ubuntu安装笔记(一):安装显卡驱动、cuda/cudnn、Anaconda、Pytorch、Tensorflow、Opencv、Visdom、FFMPEG、卸载一些不必要的预装软件
这篇文章是关于如何在Ubuntu操作系统上安装显卡驱动、CUDA、CUDNN、Anaconda、PyTorch、TensorFlow、OpenCV、FFMPEG以及卸载不必要的预装软件的详细指南。
3364 3
|
1月前
|
PyTorch TensorFlow 算法框架/工具
Jetson环境安装(一):Ubuntu18.04安装pytorch、opencv、onnx、tensorflow、setuptools、pycuda....
本文提供了在Ubuntu 18.04操作系统的NVIDIA Jetson平台上安装深度学习和计算机视觉相关库的详细步骤,包括PyTorch、OpenCV、ONNX、TensorFlow等。
44 1
Jetson环境安装(一):Ubuntu18.04安装pytorch、opencv、onnx、tensorflow、setuptools、pycuda....
|
6月前
|
机器学习/深度学习 TensorFlow 算法框架/工具
OpenCV读取tensorflow 2.X模型的方法:将SavedModel转为frozen graph
【2月更文挑战第22天】本文介绍基于Python的tensorflow库,将tensorflow与keras训练好的SavedModel格式神经网络模型转换为frozen graph格式,从而可以用OpenCV库在C++等其他语言中将其打开的方法~
146 1
OpenCV读取tensorflow 2.X模型的方法:将SavedModel转为frozen graph
|
6月前
|
机器学习/深度学习 PyTorch TensorFlow
TensorFlow、PyTorch、Keras、Scikit-learn和ChatGPT。视觉开发软件工具 Halcon、VisionPro、LabView、OpenCV
TensorFlow、PyTorch、Keras、Scikit-learn和ChatGPT。视觉开发软件工具 Halcon、VisionPro、LabView、OpenCV
108 1
|
机器学习/深度学习 人工智能 文字识别
OpenCV-字典法实现数字识别(尺寸归一化+图像差值)
OpenCV-字典法实现数字识别(尺寸归一化+图像差值)
106 0
|
机器学习/深度学习 数据可视化 PyTorch
手把手教你使用LabVIEW OpenCV DNN实现手写数字识别(含源码)
今天和大家一起来看一下在LabVIEW中如何使用OpenCV DNN模块实现手写数字识别
216 0
|
TensorFlow 算法框架/工具 计算机视觉
|
机器学习/深度学习 PyTorch 测试技术