AI计算机视觉笔记十二:基于 LeNet5 的手写数字识别及训练

本文涉及的产品
云解析 DNS,旗舰版 1个月
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介: 本文档介绍了如何使用PyTorch框架复现经典的LeNet5模型,并通过MNIST数据集进行训练与测试。首先,创建虚拟环境并安装所需库,接着下载MNIST数据集。训练部分涉及四个主要文件:`LeNet5.py`、`myDatast.py`、`readMnist.py` 和 `train.py`。通过这些文件搭建模型并完成训练过程。最后,通过测试脚本验证模型准确性,结果显示准确率达到0.986,满足预期需求。文档还提供了详细的环境配置和代码实现细节。

一、介绍

pytorch复现lenet5模型,并检测自己手写的数字图片。

利用torch框架搭建模型相对比较简单,但是也会遇到很多问题,网上资料很多,搭建模型的方法大同小异,在我尝试了自己搭建搭建出来模型,无论是训练还是检测都会遇到很多的问题,像这种自己遇到的问题,请教别人也没有用。原本使用的是github上的一份代码来复现,环境搭建完成后,才发现要有GPU,而我搭建是使用CPU,失败告终,为了复现,租用了AutoDL平台,在次搭建,这里记录GPU下的操作,CPU版本需要修改源码,自行修改,我的目的是在要训练自己的模型并在RK3568上部署,所以先训练并测试好。为后续部署作基础。

二、环境

image.png

三、搭建

1、创建虚拟环境

conda create -n LeNet5_env python==3.8

2、安装pytorch

Previous PyTorch Versions | PyTorch

根据官方PyTorch,安装pytorch,使用的是CPU版本,其他版本自行安装,安装命令:

pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
 -i https://pypi.tuna.tsinghua.edu.cn/simple

还需要安装一些其他的库

pip install matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple

3、数据集下载

http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz

直接把上面地址复制到网页上,就只可以下载

下载后保存到data/MNIST/raw目录下

image.png

四、训练代码

训练模型有四个文件分别为:LeNet5.py;myDatast.py;readMnist.py;train.py

文件LeNet5.py是网络层模型

train.py

import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.utils.data import DataLoader
from readMnist import *
from myDatast import Mnist
from LeNet5 import LeNet5

train_images = load_train_images()
train_labels = load_train_labels()

trainData = Mnist(train_images, train_labels)
train_data = DataLoader(dataset=trainData, batch_size=1, shuffle=True)

lenet5 = LeNet5()
lenet5.cuda()

lossFun = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(params=lenet5.parameters(), lr=1e-4)

Epochs = 100
L = len(train_data)

for epoch in range(Epochs):
    for i, (img, id) in enumerate(train_data):

        img = img.float()
        id = id.float()

        img = img.cuda()
        id = id.cuda()

        img = Variable(img, requires_grad=True)
        id = Variable(id, requires_grad=True)

        Output = lenet5.forward(img)
        loss = lossFun(Output, id.long())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        iter = epoch * L + i + 1
        if iter % 100 == 0:
            print('epoch:{},iter:{},loss:{:.6f}'.format(epoch + 1, iter, loss))

    torch.save(lenet5.state_dict(), 'lenet5.pth')

LeNet5.py

import torch.nn as nn


class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc1 = nn.Sequential(
            nn.Linear(in_features=16 * 4 * 4, out_features=120),
            nn.Sigmoid()
        )

        self.fc2 = nn.Sequential(
            nn.Linear(in_features=120, out_features=84),
            nn.Sigmoid()
        )

        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, img):
        img = self.conv1.forward(img)
        img = self.conv2.forward(img)

        img = img.view(img.size()[0], -1)

        img = self.fc1.forward(img)
        img = self.fc2.forward(img)
        img = self.fc3.forward(img)

        return img

readMnist.py

from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np


class Mnist(Dataset):
    def __init__(self, dataset, label):
        self.dataset = dataset
        self.label = label
        self.len = len(self.label)
        self.transforms = transforms.Compose([transforms.ToTensor() , transforms.Normalize(mean=[0.5], std=[0.5])])

    def __len__(self):
        return self.len

    def __getitem__(self, item):
        img = self.dataset[item]
        img_id = self.label[item]

        img = np.transpose(img,(1,2,0))
        img = self.transforms(img)

        return img, img_id

readMnist.py

import numpy as np
import struct
import matplotlib.pyplot as plt
import cv2

fpath = 'G:/enpei_Project_Code/21_LeNet5/LeNet5-master/myLeNet5/data/MNIST/raw/'

# 训练集文件
train_images_idx3_ubyte_file = fpath + 'train-images-idx3-ubyte'
# 训练集标签文件
train_labels_idx1_ubyte_file = fpath + 'train-labels-idx1-ubyte'

# 测试集文件
test_images_idx3_ubyte_file = fpath + 't10k-images-idx3-ubyte'
# 测试集标签文件
test_labels_idx1_ubyte_file = fpath + 't10k-labels-idx1-ubyte'


def decode_idx3_ubyte(idx3_ubyte_file):
    """
    解析idx3文件的通用函数
    :param idx3_ubyte_file: idx3文件路径
    :return: 数据集
    """
    # 读取二进制数据
    bin_data = open(idx3_ubyte_file, 'rb').read()

    # 解析文件头信息,依次为魔数、图片数量、每张图片高、每张图片宽
    offset = 0
    fmt_header = '>iiii'  # 因为数据结构中前4行的数据类型都是32位整型,所以采用i格式,但我们需要读取前4行数据,所以需要4个i。我们后面会看到标签集中,只使用2个ii。
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset)
    print('魔数:%d, 图片数量: %d张, 图片大小: %d*%d' % (magic_number, num_images, num_rows, num_cols))

    # 解析数据集
    image_size = num_rows * num_cols
    offset += struct.calcsize(fmt_header)  # 获得数据在缓存中的指针位置,从前面介绍的数据结构可以看出,读取了前4行之后,指针位置(即偏移位置offset)指向0016。
    print(offset)
    fmt_image = '>' + str(
        image_size) + 'B'  # 图像数据像素值的类型为unsigned char型,对应的format格式为B。这里还有加上图像大小784,是为了读取784个B格式数据,如果没有则只会读取一个值(即一副图像中的一个像素值)
    print(fmt_image, offset, struct.calcsize(fmt_image))
    images = np.empty((num_images, 1, num_rows, num_cols))
    # plt.figure()
    for i in range(num_images):
        if (i + 1) % 10000 == 0:
            print('已解析 %d' % (i + 1) + '张')
            print(offset)
        images[i] = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((1, num_rows, num_cols))
        # print(images[i])
        offset += struct.calcsize(fmt_image)
    #        plt.imshow(images[i],'gray')
    #        plt.pause(0.00001)
    #        plt.show()
    # plt.show()

    return images


def decode_idx1_ubyte(idx1_ubyte_file):
    """
    解析idx1文件的通用函数
    :param idx1_ubyte_file: idx1文件路径
    :return: 数据集
    """
    # 读取二进制数据
    bin_data = open(idx1_ubyte_file, 'rb').read()

    # 解析文件头信息,依次为魔数和标签数
    offset = 0
    fmt_header = '>ii'
    magic_number, num_images = struct.unpack_from(fmt_header, bin_data, offset)
    print('魔数:%d, 图片数量: %d张' % (magic_number, num_images))

    # 解析数据集
    offset += struct.calcsize(fmt_header)
    fmt_image = '>B'
    labels = np.empty(num_images)
    for i in range(num_images):
        if (i + 1) % 10000 == 0:
            print('已解析 %d' % (i + 1) + '张')
        labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]
        offset += struct.calcsize(fmt_image)
    return labels


def load_train_images(idx_ubyte_file=train_images_idx3_ubyte_file):
    """
    TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000803(2051) magic number
    0004     32 bit integer  60000            number of images
    0008     32 bit integer  28               number of rows
    0012     32 bit integer  28               number of columns
    0016     unsigned byte   ??               pixel
    0017     unsigned byte   ??               pixel
    ........
    xxxx     unsigned byte   ??               pixel
    Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).
    :param idx_ubyte_file: idx文件路径
    :return: n*row*col维np.array对象,n为图片数量
    """
    return decode_idx3_ubyte(idx_ubyte_file)


def load_train_labels(idx_ubyte_file=train_labels_idx1_ubyte_file):
    """
    TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000801(2049) magic number (MSB first)
    0004     32 bit integer  60000            number of items
    0008     unsigned byte   ??               label
    0009     unsigned byte   ??               label
    ........
    xxxx     unsigned byte   ??               label
    The labels values are 0 to 9.
    :param idx_ubyte_file: idx文件路径
    :return: n*1维np.array对象,n为图片数量
    """
    return decode_idx1_ubyte(idx_ubyte_file)


def load_test_images(idx_ubyte_file=test_images_idx3_ubyte_file):
    """
    TEST SET IMAGE FILE (t10k-images-idx3-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000803(2051) magic number
    0004     32 bit integer  10000            number of images
    0008     32 bit integer  28               number of rows
    0012     32 bit integer  28               number of columns
    0016     unsigned byte   ??               pixel
    0017     unsigned byte   ??               pixel
    ........
    xxxx     unsigned byte   ??               pixel
    Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).
    :param idx_ubyte_file: idx文件路径
    :return: n*row*col维np.array对象,n为图片数量
    """
    return decode_idx3_ubyte(idx_ubyte_file)


def load_test_labels(idx_ubyte_file=test_labels_idx1_ubyte_file):
    """
    TEST SET LABEL FILE (t10k-labels-idx1-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000801(2049) magic number (MSB first)
    0004     32 bit integer  10000            number of items
    0008     unsigned byte   ??               label
    0009     unsigned byte   ??               label
    ........
    xxxx     unsigned byte   ??               label
    The labels values are 0 to 9.
    :param idx_ubyte_file: idx文件路径
    :return: n*1维np.array对象,n为图片数量
    """
    return decode_idx1_ubyte(idx_ubyte_file)


if __name__ == '__main__':

    train_images = load_train_images()
    train_labels = load_train_labels()
    test_images = load_test_images()
    test_labels = load_test_labels()

    pass

    # 查看前十个数据及其标签以读取是否正确
    for i in range(10):
        print(train_labels[i])

        img = train_images[i]
        img = np.transpose(img, (1, 2, 0))

        cv2.namedWindow('img')
        cv2.imshow('img', img)
        cv2.waitKey(100)

    print('done')

上面代码需要注意的是数据集的路径,需要修改成对应的路径。
image.png

运行python train.py

image.png

五、测试

from LeNet5 import LeNet5
import torch
from readMnist import *
from myDatast import Mnist
from torch.utils.data import DataLoader
import numpy as np
import cv2

test_images = load_test_images()
test_labels = load_test_labels()

testData = Mnist(test_images, test_labels)
test_data = DataLoader(dataset=testData, batch_size=1, shuffle=True)

lenet5 = LeNet5()
lenet5.load_state_dict(torch.load('lenet5.pth'))
lenet5.eval()

showimg = True
js = 0
for i, (img, id) in enumerate(test_data):

    img = img.float()
    outid = lenet5(img)

    oid = torch.argmax(outid)
    if oid == id:
        js = js + 1

    if showimg == True:
        img = img.numpy()
        img = np.squeeze(img)

        id = id.numpy()
        id = np.squeeze(id)
        id = np.int32(id)

        oid = oid.numpy()
        oid = np.squeeze(oid)

        maxv = np.max(img)
        minv = np.min(img)

        img = (img - minv) / (maxv - minv)

        cv2.namedWindow("img", 0)
        cv2.imshow("img", img)

        title = "img, predicted value:{},truth value:{}".format(oid, id)
        cv2.setWindowTitle("img",title)

        cv2.waitKey(1)

print('准确率:{:.6f}'.format(js / (i + 1)))

测试结果准确率达到0.986基本达到要求

image.png

相关文章
|
2月前
|
人工智能 测试技术 API
AI计算机视觉笔记二十 九:yolov10竹签模型,自动数竹签
本文介绍了如何在AutoDL平台上搭建YOLOv10环境并进行竹签检测与计数。首先从官网下载YOLOv10源码并创建虚拟环境,安装依赖库。接着通过官方模型测试环境是否正常工作。然后下载自定义数据集并配置`mycoco128.yaml`文件,使用`yolo detect train`命令或Python代码进行训练。最后,通过命令行或API调用测试训练结果,并展示竹签计数功能。如需转载,请注明原文出处。
|
1月前
|
Python 机器学习/深度学习 人工智能
手把手教你从零开始构建并训练你的第一个强化学习智能体:深入浅出Agent项目实战,带你体验编程与AI结合的乐趣
【10月更文挑战第1天】本文通过构建一个简单的强化学习环境,演示了如何创建和训练智能体以完成特定任务。我们使用Python、OpenAI Gym和PyTorch搭建了一个基础的智能体,使其学会在CartPole-v1环境中保持杆子不倒。文中详细介绍了环境设置、神经网络构建及训练过程。此实战案例有助于理解智能体的工作原理及基本训练方法,为更复杂应用奠定基础。首先需安装必要库: ```bash pip install gym torch ``` 接着定义环境并与之交互,实现智能体的训练。通过多个回合的试错学习,智能体逐步优化其策略。这一过程虽从基础做起,但为后续研究提供了良好起点。
95 4
手把手教你从零开始构建并训练你的第一个强化学习智能体:深入浅出Agent项目实战,带你体验编程与AI结合的乐趣
|
2月前
|
机器学习/深度学习 人工智能 PyTorch
AI计算机视觉笔记三十二:LPRNet车牌识别
LPRNet是一种基于Pytorch的高性能、轻量级车牌识别框架,适用于中国及其他国家的车牌识别。该网络无需对字符进行预分割,采用端到端的轻量化设计,结合了squeezenet和inception的思想。其创新点在于去除了RNN,仅使用CNN与CTC Loss,并通过特定的卷积模块提取上下文信息。环境配置包括使用CPU开发板和Autodl训练环境。训练和测试过程需搭建虚拟环境并安装相关依赖,执行训练和测试脚本时可能遇到若干错误,需相应调整代码以确保正确运行。使用官方模型可获得较高的识别准确率,自行训练时建议增加训练轮数以提升效果。
|
2月前
|
人工智能 开发工具 计算机视觉
AI计算机视觉笔记三十:yolov8_obb旋转框训练
本文介绍了如何使用AUTODL环境搭建YOLOv8-obb的训练流程。首先创建虚拟环境并激活,然后通过指定清华源安装ultralytics库。接着下载YOLOv8源码,并使用指定命令开始训练,过程中可能会下载yolov8n.pt文件。训练完成后,可使用相应命令进行预测测试。
|
2月前
|
人工智能 并行计算 测试技术
AI计算机视觉笔记三十一:基于UNetMultiLane的多车道线等识别
该项目基于开源数据集 VIL100 实现了 UNetMultiLane,用于多车道线及车道线类型的识别。数据集中标注了六个车道的车道线及其类型。项目详细记录了从环境搭建到模型训练与测试的全过程,并提供了在 CPU 上进行训练和 ONNX 转换的代码示例。训练过程约需 4 小时完成 50 个 epoch。此外,还实现了视频检测功能,可在视频中实时识别车道线及其类型。
|
机器学习/深度学习 人工智能 算法
【计算机视觉】究竟谁能解决可解释性 AI?
准确性与可解释性是不能同时达到的吗?来自 IEEE 研究员 Cuntai Guan 这样认为:“许多机器决策仍然没有得到很好的理解”。大多数论文甚至提出在准确性和可解释性之间进行严格区分。
1321 0
【计算机视觉】究竟谁能解决可解释性 AI?
|
5天前
|
人工智能 JSON API
阿里云文档智能 & RAG解决方案:提升AI大模型业务理解与应用
阿里云推出的文档智能 & RAG解决方案,旨在通过先进的文档解析技术和检索增强生成(RAG)方法,显著提升人工智能大模型在业务场景中的应用效果。该方案通过文档智能(Document Mind)技术将非结构化文档内容转换为结构化数据,提取文档的层级树、样式和版面信息,并输出为Markdown和Json格式,为RAG提供语义分块策略。这一过程不仅解决了文档内容解析错误和切块丢失语义信息的问题,还优化了输出LLM友好的Markdown信息。方案的优势在于其多格式支持能力,能够处理包括Office文档、PDF、Html、图片在内的主流文件类型,返回文档的样式、版面信息和层级树结构。
34 2
|
5天前
|
人工智能 搜索推荐 安全
AI技术在医疗领域的应用与挑战
【10月更文挑战第27天】 本文探讨了人工智能(AI)在医疗领域的应用,包括疾病诊断、药物研发和患者管理等方面。同时,也分析了AI在医疗领域面临的挑战,如数据隐私、伦理问题和技术局限性等。通过对这些方面的深入分析,我们可以更好地理解AI在医疗领域的潜力和发展方向。
104 59
|
1天前
|
存储 XML 人工智能
深度解读AI在数字档案馆中的创新应用:高效识别与智能档案管理
基于OCR技术的纸质档案电子化方案,通过先进的AI能力平台,实现手写、打印、复古文档等多格式高效识别与智能归档。该方案大幅提升了档案管理效率,确保数据安全与隐私,为档案馆提供全面、智能化的电子化管理解决方案。
63 48
|
6天前
|
人工智能 前端开发 Java
基于开源框架Spring AI Alibaba快速构建Java应用
本文旨在帮助开发者快速掌握并应用 Spring AI Alibaba,提升基于 Java 的大模型应用开发效率和安全性。
基于开源框架Spring AI Alibaba快速构建Java应用

热门文章

最新文章