AI计算机视觉笔记十二:基于 LeNet5 的手写数字识别及训练

本文涉及的产品
云解析 DNS,旗舰版 1个月
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介: 本文档介绍了如何使用PyTorch框架复现经典的LeNet5模型,并通过MNIST数据集进行训练与测试。首先,创建虚拟环境并安装所需库,接着下载MNIST数据集。训练部分涉及四个主要文件:`LeNet5.py`、`myDatast.py`、`readMnist.py` 和 `train.py`。通过这些文件搭建模型并完成训练过程。最后,通过测试脚本验证模型准确性,结果显示准确率达到0.986,满足预期需求。文档还提供了详细的环境配置和代码实现细节。

一、介绍

pytorch复现lenet5模型,并检测自己手写的数字图片。

利用torch框架搭建模型相对比较简单,但是也会遇到很多问题,网上资料很多,搭建模型的方法大同小异,在我尝试了自己搭建搭建出来模型,无论是训练还是检测都会遇到很多的问题,像这种自己遇到的问题,请教别人也没有用。原本使用的是github上的一份代码来复现,环境搭建完成后,才发现要有GPU,而我搭建是使用CPU,失败告终,为了复现,租用了AutoDL平台,在次搭建,这里记录GPU下的操作,CPU版本需要修改源码,自行修改,我的目的是在要训练自己的模型并在RK3568上部署,所以先训练并测试好。为后续部署作基础。

二、环境

image.png

三、搭建

1、创建虚拟环境

conda create -n LeNet5_env python==3.8

2、安装pytorch

Previous PyTorch Versions | PyTorch

根据官方PyTorch,安装pytorch,使用的是CPU版本,其他版本自行安装,安装命令:

pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
 -i https://pypi.tuna.tsinghua.edu.cn/simple

还需要安装一些其他的库

pip install matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple

3、数据集下载

http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz

直接把上面地址复制到网页上,就只可以下载

下载后保存到data/MNIST/raw目录下

image.png

四、训练代码

训练模型有四个文件分别为:LeNet5.py;myDatast.py;readMnist.py;train.py

文件LeNet5.py是网络层模型

train.py

import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.utils.data import DataLoader
from readMnist import *
from myDatast import Mnist
from LeNet5 import LeNet5

train_images = load_train_images()
train_labels = load_train_labels()

trainData = Mnist(train_images, train_labels)
train_data = DataLoader(dataset=trainData, batch_size=1, shuffle=True)

lenet5 = LeNet5()
lenet5.cuda()

lossFun = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(params=lenet5.parameters(), lr=1e-4)

Epochs = 100
L = len(train_data)

for epoch in range(Epochs):
    for i, (img, id) in enumerate(train_data):

        img = img.float()
        id = id.float()

        img = img.cuda()
        id = id.cuda()

        img = Variable(img, requires_grad=True)
        id = Variable(id, requires_grad=True)

        Output = lenet5.forward(img)
        loss = lossFun(Output, id.long())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        iter = epoch * L + i + 1
        if iter % 100 == 0:
            print('epoch:{},iter:{},loss:{:.6f}'.format(epoch + 1, iter, loss))

    torch.save(lenet5.state_dict(), 'lenet5.pth')

LeNet5.py

import torch.nn as nn


class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc1 = nn.Sequential(
            nn.Linear(in_features=16 * 4 * 4, out_features=120),
            nn.Sigmoid()
        )

        self.fc2 = nn.Sequential(
            nn.Linear(in_features=120, out_features=84),
            nn.Sigmoid()
        )

        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, img):
        img = self.conv1.forward(img)
        img = self.conv2.forward(img)

        img = img.view(img.size()[0], -1)

        img = self.fc1.forward(img)
        img = self.fc2.forward(img)
        img = self.fc3.forward(img)

        return img

readMnist.py

from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np


class Mnist(Dataset):
    def __init__(self, dataset, label):
        self.dataset = dataset
        self.label = label
        self.len = len(self.label)
        self.transforms = transforms.Compose([transforms.ToTensor() , transforms.Normalize(mean=[0.5], std=[0.5])])

    def __len__(self):
        return self.len

    def __getitem__(self, item):
        img = self.dataset[item]
        img_id = self.label[item]

        img = np.transpose(img,(1,2,0))
        img = self.transforms(img)

        return img, img_id

readMnist.py

import numpy as np
import struct
import matplotlib.pyplot as plt
import cv2

fpath = 'G:/enpei_Project_Code/21_LeNet5/LeNet5-master/myLeNet5/data/MNIST/raw/'

# 训练集文件
train_images_idx3_ubyte_file = fpath + 'train-images-idx3-ubyte'
# 训练集标签文件
train_labels_idx1_ubyte_file = fpath + 'train-labels-idx1-ubyte'

# 测试集文件
test_images_idx3_ubyte_file = fpath + 't10k-images-idx3-ubyte'
# 测试集标签文件
test_labels_idx1_ubyte_file = fpath + 't10k-labels-idx1-ubyte'


def decode_idx3_ubyte(idx3_ubyte_file):
    """
    解析idx3文件的通用函数
    :param idx3_ubyte_file: idx3文件路径
    :return: 数据集
    """
    # 读取二进制数据
    bin_data = open(idx3_ubyte_file, 'rb').read()

    # 解析文件头信息,依次为魔数、图片数量、每张图片高、每张图片宽
    offset = 0
    fmt_header = '>iiii'  # 因为数据结构中前4行的数据类型都是32位整型,所以采用i格式,但我们需要读取前4行数据,所以需要4个i。我们后面会看到标签集中,只使用2个ii。
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset)
    print('魔数:%d, 图片数量: %d张, 图片大小: %d*%d' % (magic_number, num_images, num_rows, num_cols))

    # 解析数据集
    image_size = num_rows * num_cols
    offset += struct.calcsize(fmt_header)  # 获得数据在缓存中的指针位置,从前面介绍的数据结构可以看出,读取了前4行之后,指针位置(即偏移位置offset)指向0016。
    print(offset)
    fmt_image = '>' + str(
        image_size) + 'B'  # 图像数据像素值的类型为unsigned char型,对应的format格式为B。这里还有加上图像大小784,是为了读取784个B格式数据,如果没有则只会读取一个值(即一副图像中的一个像素值)
    print(fmt_image, offset, struct.calcsize(fmt_image))
    images = np.empty((num_images, 1, num_rows, num_cols))
    # plt.figure()
    for i in range(num_images):
        if (i + 1) % 10000 == 0:
            print('已解析 %d' % (i + 1) + '张')
            print(offset)
        images[i] = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((1, num_rows, num_cols))
        # print(images[i])
        offset += struct.calcsize(fmt_image)
    #        plt.imshow(images[i],'gray')
    #        plt.pause(0.00001)
    #        plt.show()
    # plt.show()

    return images


def decode_idx1_ubyte(idx1_ubyte_file):
    """
    解析idx1文件的通用函数
    :param idx1_ubyte_file: idx1文件路径
    :return: 数据集
    """
    # 读取二进制数据
    bin_data = open(idx1_ubyte_file, 'rb').read()

    # 解析文件头信息,依次为魔数和标签数
    offset = 0
    fmt_header = '>ii'
    magic_number, num_images = struct.unpack_from(fmt_header, bin_data, offset)
    print('魔数:%d, 图片数量: %d张' % (magic_number, num_images))

    # 解析数据集
    offset += struct.calcsize(fmt_header)
    fmt_image = '>B'
    labels = np.empty(num_images)
    for i in range(num_images):
        if (i + 1) % 10000 == 0:
            print('已解析 %d' % (i + 1) + '张')
        labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]
        offset += struct.calcsize(fmt_image)
    return labels


def load_train_images(idx_ubyte_file=train_images_idx3_ubyte_file):
    """
    TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000803(2051) magic number
    0004     32 bit integer  60000            number of images
    0008     32 bit integer  28               number of rows
    0012     32 bit integer  28               number of columns
    0016     unsigned byte   ??               pixel
    0017     unsigned byte   ??               pixel
    ........
    xxxx     unsigned byte   ??               pixel
    Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).
    :param idx_ubyte_file: idx文件路径
    :return: n*row*col维np.array对象,n为图片数量
    """
    return decode_idx3_ubyte(idx_ubyte_file)


def load_train_labels(idx_ubyte_file=train_labels_idx1_ubyte_file):
    """
    TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000801(2049) magic number (MSB first)
    0004     32 bit integer  60000            number of items
    0008     unsigned byte   ??               label
    0009     unsigned byte   ??               label
    ........
    xxxx     unsigned byte   ??               label
    The labels values are 0 to 9.
    :param idx_ubyte_file: idx文件路径
    :return: n*1维np.array对象,n为图片数量
    """
    return decode_idx1_ubyte(idx_ubyte_file)


def load_test_images(idx_ubyte_file=test_images_idx3_ubyte_file):
    """
    TEST SET IMAGE FILE (t10k-images-idx3-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000803(2051) magic number
    0004     32 bit integer  10000            number of images
    0008     32 bit integer  28               number of rows
    0012     32 bit integer  28               number of columns
    0016     unsigned byte   ??               pixel
    0017     unsigned byte   ??               pixel
    ........
    xxxx     unsigned byte   ??               pixel
    Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).
    :param idx_ubyte_file: idx文件路径
    :return: n*row*col维np.array对象,n为图片数量
    """
    return decode_idx3_ubyte(idx_ubyte_file)


def load_test_labels(idx_ubyte_file=test_labels_idx1_ubyte_file):
    """
    TEST SET LABEL FILE (t10k-labels-idx1-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000801(2049) magic number (MSB first)
    0004     32 bit integer  10000            number of items
    0008     unsigned byte   ??               label
    0009     unsigned byte   ??               label
    ........
    xxxx     unsigned byte   ??               label
    The labels values are 0 to 9.
    :param idx_ubyte_file: idx文件路径
    :return: n*1维np.array对象,n为图片数量
    """
    return decode_idx1_ubyte(idx_ubyte_file)


if __name__ == '__main__':

    train_images = load_train_images()
    train_labels = load_train_labels()
    test_images = load_test_images()
    test_labels = load_test_labels()

    pass

    # 查看前十个数据及其标签以读取是否正确
    for i in range(10):
        print(train_labels[i])

        img = train_images[i]
        img = np.transpose(img, (1, 2, 0))

        cv2.namedWindow('img')
        cv2.imshow('img', img)
        cv2.waitKey(100)

    print('done')

上面代码需要注意的是数据集的路径,需要修改成对应的路径。
image.png

运行python train.py

image.png

五、测试

from LeNet5 import LeNet5
import torch
from readMnist import *
from myDatast import Mnist
from torch.utils.data import DataLoader
import numpy as np
import cv2

test_images = load_test_images()
test_labels = load_test_labels()

testData = Mnist(test_images, test_labels)
test_data = DataLoader(dataset=testData, batch_size=1, shuffle=True)

lenet5 = LeNet5()
lenet5.load_state_dict(torch.load('lenet5.pth'))
lenet5.eval()

showimg = True
js = 0
for i, (img, id) in enumerate(test_data):

    img = img.float()
    outid = lenet5(img)

    oid = torch.argmax(outid)
    if oid == id:
        js = js + 1

    if showimg == True:
        img = img.numpy()
        img = np.squeeze(img)

        id = id.numpy()
        id = np.squeeze(id)
        id = np.int32(id)

        oid = oid.numpy()
        oid = np.squeeze(oid)

        maxv = np.max(img)
        minv = np.min(img)

        img = (img - minv) / (maxv - minv)

        cv2.namedWindow("img", 0)
        cv2.imshow("img", img)

        title = "img, predicted value:{},truth value:{}".format(oid, id)
        cv2.setWindowTitle("img",title)

        cv2.waitKey(1)

print('准确率:{:.6f}'.format(js / (i + 1)))

测试结果准确率达到0.986基本达到要求

image.png

相关文章
|
6天前
|
人工智能 测试技术 API
AI计算机视觉笔记二十 九:yolov10竹签模型,自动数竹签
本文介绍了如何在AutoDL平台上搭建YOLOv10环境并进行竹签检测与计数。首先从官网下载YOLOv10源码并创建虚拟环境,安装依赖库。接着通过官方模型测试环境是否正常工作。然后下载自定义数据集并配置`mycoco128.yaml`文件,使用`yolo detect train`命令或Python代码进行训练。最后,通过命令行或API调用测试训练结果,并展示竹签计数功能。如需转载,请注明原文出处。
|
6天前
|
机器学习/深度学习 人工智能 PyTorch
AI计算机视觉笔记三十二:LPRNet车牌识别
LPRNet是一种基于Pytorch的高性能、轻量级车牌识别框架,适用于中国及其他国家的车牌识别。该网络无需对字符进行预分割,采用端到端的轻量化设计,结合了squeezenet和inception的思想。其创新点在于去除了RNN,仅使用CNN与CTC Loss,并通过特定的卷积模块提取上下文信息。环境配置包括使用CPU开发板和Autodl训练环境。训练和测试过程需搭建虚拟环境并安装相关依赖,执行训练和测试脚本时可能遇到若干错误,需相应调整代码以确保正确运行。使用官方模型可获得较高的识别准确率,自行训练时建议增加训练轮数以提升效果。
|
6天前
|
人工智能 开发工具 计算机视觉
AI计算机视觉笔记三十:yolov8_obb旋转框训练
本文介绍了如何使用AUTODL环境搭建YOLOv8-obb的训练流程。首先创建虚拟环境并激活,然后通过指定清华源安装ultralytics库。接着下载YOLOv8源码,并使用指定命令开始训练,过程中可能会下载yolov8n.pt文件。训练完成后,可使用相应命令进行预测测试。
|
6天前
|
人工智能 并行计算 测试技术
AI计算机视觉笔记三十一:基于UNetMultiLane的多车道线等识别
该项目基于开源数据集 VIL100 实现了 UNetMultiLane,用于多车道线及车道线类型的识别。数据集中标注了六个车道的车道线及其类型。项目详细记录了从环境搭建到模型训练与测试的全过程,并提供了在 CPU 上进行训练和 ONNX 转换的代码示例。训练过程约需 4 小时完成 50 个 epoch。此外,还实现了视频检测功能,可在视频中实时识别车道线及其类型。
|
6天前
|
人工智能 监控 算法
AI计算机视觉笔记二十 八:基于YOLOv8实例分割的DeepSORT多目标跟踪
本文介绍了YOLOv8实例分割与DeepSORT视觉跟踪算法的结合应用,通过YOLOv8进行目标检测分割,并利用DeepSORT实现特征跟踪,在复杂环境中保持目标跟踪的准确性与稳定性。该技术广泛应用于安全监控、无人驾驶等领域。文章提供了环境搭建、代码下载及测试步骤,并附有详细代码示例。
|
4月前
|
机器学习/深度学习 计算机视觉
AIGC核心技术——计算机视觉(CV)预训练大模型
【1月更文挑战第13天】AIGC核心技术——计算机视觉(CV)预训练大模型
539 3
AIGC核心技术——计算机视觉(CV)预训练大模型
|
9月前
|
机器学习/深度学习 PyTorch 算法框架/工具
Azure 机器学习 - 使用 ONNX 对来自 AutoML 的计算机视觉模型进行预测
Azure 机器学习 - 使用 ONNX 对来自 AutoML 的计算机视觉模型进行预测
107 0
|
6天前
|
人工智能 测试技术 PyTorch
AI计算机视觉笔记二十四:YOLOP 训练+测试+模型评估
本文介绍了通过正点原子的ATK-3568了解并实现YOLOP(You Only Look Once for Panoptic Driving Perception)的过程,包括训练、测试、转换为ONNX格式及在ONNX Runtime上的部署。YOLOP由华中科技大学团队于2021年发布,可在Jetson TX2上达到23FPS,实现了目标检测、可行驶区域分割和车道线检测的多任务学习。文章详细记录了环境搭建、训练数据准备、模型转换和测试等步骤,并解决了ONNX转换过程中的问题。
|
2月前
|
自然语言处理 监控 自动驾驶
大模型在自然语言处理(NLP)、计算机视觉(CV)和多模态模型等领域应用最广
【7月更文挑战第26天】大模型在自然语言处理(NLP)、计算机视觉(CV)和多模态模型等领域应用最广
60 11
|
3月前
|
编解码 机器人 测试技术
2024年6月计算机视觉论文推荐:扩散模型、视觉语言模型、视频生成等
6月还有一周就要结束了,我们今天来总结2024年6月上半月发表的最重要的论文,重点介绍了计算机视觉领域的最新研究和进展。
108 8