AI计算机视觉笔记十二:基于 LeNet5 的手写数字识别及训练

本文涉及的产品
云解析 DNS,旗舰版 1个月
全局流量管理 GTM,标准版 1个月
公共DNS(含HTTPDNS解析),每月1000万次HTTP解析
简介: 本文档介绍了如何使用PyTorch框架复现经典的LeNet5模型,并通过MNIST数据集进行训练与测试。首先,创建虚拟环境并安装所需库,接着下载MNIST数据集。训练部分涉及四个主要文件:`LeNet5.py`、`myDatast.py`、`readMnist.py` 和 `train.py`。通过这些文件搭建模型并完成训练过程。最后,通过测试脚本验证模型准确性,结果显示准确率达到0.986,满足预期需求。文档还提供了详细的环境配置和代码实现细节。

一、介绍

pytorch复现lenet5模型,并检测自己手写的数字图片。

利用torch框架搭建模型相对比较简单,但是也会遇到很多问题,网上资料很多,搭建模型的方法大同小异,在我尝试了自己搭建搭建出来模型,无论是训练还是检测都会遇到很多的问题,像这种自己遇到的问题,请教别人也没有用。原本使用的是github上的一份代码来复现,环境搭建完成后,才发现要有GPU,而我搭建是使用CPU,失败告终,为了复现,租用了AutoDL平台,在次搭建,这里记录GPU下的操作,CPU版本需要修改源码,自行修改,我的目的是在要训练自己的模型并在RK3568上部署,所以先训练并测试好。为后续部署作基础。

二、环境

image.png

三、搭建

1、创建虚拟环境

conda create -n LeNet5_env python==3.8

2、安装pytorch

Previous PyTorch Versions | PyTorch

根据官方PyTorch,安装pytorch,使用的是CPU版本,其他版本自行安装,安装命令:

pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
 -i https://pypi.tuna.tsinghua.edu.cn/simple

还需要安装一些其他的库

pip install matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple

3、数据集下载

http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz

直接把上面地址复制到网页上,就只可以下载

下载后保存到data/MNIST/raw目录下

image.png

四、训练代码

训练模型有四个文件分别为:LeNet5.py;myDatast.py;readMnist.py;train.py

文件LeNet5.py是网络层模型

train.py

import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.utils.data import DataLoader
from readMnist import *
from myDatast import Mnist
from LeNet5 import LeNet5

train_images = load_train_images()
train_labels = load_train_labels()

trainData = Mnist(train_images, train_labels)
train_data = DataLoader(dataset=trainData, batch_size=1, shuffle=True)

lenet5 = LeNet5()
lenet5.cuda()

lossFun = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(params=lenet5.parameters(), lr=1e-4)

Epochs = 100
L = len(train_data)

for epoch in range(Epochs):
    for i, (img, id) in enumerate(train_data):

        img = img.float()
        id = id.float()

        img = img.cuda()
        id = id.cuda()

        img = Variable(img, requires_grad=True)
        id = Variable(id, requires_grad=True)

        Output = lenet5.forward(img)
        loss = lossFun(Output, id.long())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        iter = epoch * L + i + 1
        if iter % 100 == 0:
            print('epoch:{},iter:{},loss:{:.6f}'.format(epoch + 1, iter, loss))

    torch.save(lenet5.state_dict(), 'lenet5.pth')

LeNet5.py

import torch.nn as nn


class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc1 = nn.Sequential(
            nn.Linear(in_features=16 * 4 * 4, out_features=120),
            nn.Sigmoid()
        )

        self.fc2 = nn.Sequential(
            nn.Linear(in_features=120, out_features=84),
            nn.Sigmoid()
        )

        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, img):
        img = self.conv1.forward(img)
        img = self.conv2.forward(img)

        img = img.view(img.size()[0], -1)

        img = self.fc1.forward(img)
        img = self.fc2.forward(img)
        img = self.fc3.forward(img)

        return img

readMnist.py

from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np


class Mnist(Dataset):
    def __init__(self, dataset, label):
        self.dataset = dataset
        self.label = label
        self.len = len(self.label)
        self.transforms = transforms.Compose([transforms.ToTensor() , transforms.Normalize(mean=[0.5], std=[0.5])])

    def __len__(self):
        return self.len

    def __getitem__(self, item):
        img = self.dataset[item]
        img_id = self.label[item]

        img = np.transpose(img,(1,2,0))
        img = self.transforms(img)

        return img, img_id

readMnist.py

import numpy as np
import struct
import matplotlib.pyplot as plt
import cv2

fpath = 'G:/enpei_Project_Code/21_LeNet5/LeNet5-master/myLeNet5/data/MNIST/raw/'

# 训练集文件
train_images_idx3_ubyte_file = fpath + 'train-images-idx3-ubyte'
# 训练集标签文件
train_labels_idx1_ubyte_file = fpath + 'train-labels-idx1-ubyte'

# 测试集文件
test_images_idx3_ubyte_file = fpath + 't10k-images-idx3-ubyte'
# 测试集标签文件
test_labels_idx1_ubyte_file = fpath + 't10k-labels-idx1-ubyte'


def decode_idx3_ubyte(idx3_ubyte_file):
    """
    解析idx3文件的通用函数
    :param idx3_ubyte_file: idx3文件路径
    :return: 数据集
    """
    # 读取二进制数据
    bin_data = open(idx3_ubyte_file, 'rb').read()

    # 解析文件头信息,依次为魔数、图片数量、每张图片高、每张图片宽
    offset = 0
    fmt_header = '>iiii'  # 因为数据结构中前4行的数据类型都是32位整型,所以采用i格式,但我们需要读取前4行数据,所以需要4个i。我们后面会看到标签集中,只使用2个ii。
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset)
    print('魔数:%d, 图片数量: %d张, 图片大小: %d*%d' % (magic_number, num_images, num_rows, num_cols))

    # 解析数据集
    image_size = num_rows * num_cols
    offset += struct.calcsize(fmt_header)  # 获得数据在缓存中的指针位置,从前面介绍的数据结构可以看出,读取了前4行之后,指针位置(即偏移位置offset)指向0016。
    print(offset)
    fmt_image = '>' + str(
        image_size) + 'B'  # 图像数据像素值的类型为unsigned char型,对应的format格式为B。这里还有加上图像大小784,是为了读取784个B格式数据,如果没有则只会读取一个值(即一副图像中的一个像素值)
    print(fmt_image, offset, struct.calcsize(fmt_image))
    images = np.empty((num_images, 1, num_rows, num_cols))
    # plt.figure()
    for i in range(num_images):
        if (i + 1) % 10000 == 0:
            print('已解析 %d' % (i + 1) + '张')
            print(offset)
        images[i] = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((1, num_rows, num_cols))
        # print(images[i])
        offset += struct.calcsize(fmt_image)
    #        plt.imshow(images[i],'gray')
    #        plt.pause(0.00001)
    #        plt.show()
    # plt.show()

    return images


def decode_idx1_ubyte(idx1_ubyte_file):
    """
    解析idx1文件的通用函数
    :param idx1_ubyte_file: idx1文件路径
    :return: 数据集
    """
    # 读取二进制数据
    bin_data = open(idx1_ubyte_file, 'rb').read()

    # 解析文件头信息,依次为魔数和标签数
    offset = 0
    fmt_header = '>ii'
    magic_number, num_images = struct.unpack_from(fmt_header, bin_data, offset)
    print('魔数:%d, 图片数量: %d张' % (magic_number, num_images))

    # 解析数据集
    offset += struct.calcsize(fmt_header)
    fmt_image = '>B'
    labels = np.empty(num_images)
    for i in range(num_images):
        if (i + 1) % 10000 == 0:
            print('已解析 %d' % (i + 1) + '张')
        labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]
        offset += struct.calcsize(fmt_image)
    return labels


def load_train_images(idx_ubyte_file=train_images_idx3_ubyte_file):
    """
    TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000803(2051) magic number
    0004     32 bit integer  60000            number of images
    0008     32 bit integer  28               number of rows
    0012     32 bit integer  28               number of columns
    0016     unsigned byte   ??               pixel
    0017     unsigned byte   ??               pixel
    ........
    xxxx     unsigned byte   ??               pixel
    Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).
    :param idx_ubyte_file: idx文件路径
    :return: n*row*col维np.array对象,n为图片数量
    """
    return decode_idx3_ubyte(idx_ubyte_file)


def load_train_labels(idx_ubyte_file=train_labels_idx1_ubyte_file):
    """
    TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000801(2049) magic number (MSB first)
    0004     32 bit integer  60000            number of items
    0008     unsigned byte   ??               label
    0009     unsigned byte   ??               label
    ........
    xxxx     unsigned byte   ??               label
    The labels values are 0 to 9.
    :param idx_ubyte_file: idx文件路径
    :return: n*1维np.array对象,n为图片数量
    """
    return decode_idx1_ubyte(idx_ubyte_file)


def load_test_images(idx_ubyte_file=test_images_idx3_ubyte_file):
    """
    TEST SET IMAGE FILE (t10k-images-idx3-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000803(2051) magic number
    0004     32 bit integer  10000            number of images
    0008     32 bit integer  28               number of rows
    0012     32 bit integer  28               number of columns
    0016     unsigned byte   ??               pixel
    0017     unsigned byte   ??               pixel
    ........
    xxxx     unsigned byte   ??               pixel
    Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).
    :param idx_ubyte_file: idx文件路径
    :return: n*row*col维np.array对象,n为图片数量
    """
    return decode_idx3_ubyte(idx_ubyte_file)


def load_test_labels(idx_ubyte_file=test_labels_idx1_ubyte_file):
    """
    TEST SET LABEL FILE (t10k-labels-idx1-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000801(2049) magic number (MSB first)
    0004     32 bit integer  10000            number of items
    0008     unsigned byte   ??               label
    0009     unsigned byte   ??               label
    ........
    xxxx     unsigned byte   ??               label
    The labels values are 0 to 9.
    :param idx_ubyte_file: idx文件路径
    :return: n*1维np.array对象,n为图片数量
    """
    return decode_idx1_ubyte(idx_ubyte_file)


if __name__ == '__main__':

    train_images = load_train_images()
    train_labels = load_train_labels()
    test_images = load_test_images()
    test_labels = load_test_labels()

    pass

    # 查看前十个数据及其标签以读取是否正确
    for i in range(10):
        print(train_labels[i])

        img = train_images[i]
        img = np.transpose(img, (1, 2, 0))

        cv2.namedWindow('img')
        cv2.imshow('img', img)
        cv2.waitKey(100)

    print('done')

上面代码需要注意的是数据集的路径,需要修改成对应的路径。
image.png

运行python train.py

image.png

五、测试

from LeNet5 import LeNet5
import torch
from readMnist import *
from myDatast import Mnist
from torch.utils.data import DataLoader
import numpy as np
import cv2

test_images = load_test_images()
test_labels = load_test_labels()

testData = Mnist(test_images, test_labels)
test_data = DataLoader(dataset=testData, batch_size=1, shuffle=True)

lenet5 = LeNet5()
lenet5.load_state_dict(torch.load('lenet5.pth'))
lenet5.eval()

showimg = True
js = 0
for i, (img, id) in enumerate(test_data):

    img = img.float()
    outid = lenet5(img)

    oid = torch.argmax(outid)
    if oid == id:
        js = js + 1

    if showimg == True:
        img = img.numpy()
        img = np.squeeze(img)

        id = id.numpy()
        id = np.squeeze(id)
        id = np.int32(id)

        oid = oid.numpy()
        oid = np.squeeze(oid)

        maxv = np.max(img)
        minv = np.min(img)

        img = (img - minv) / (maxv - minv)

        cv2.namedWindow("img", 0)
        cv2.imshow("img", img)

        title = "img, predicted value:{},truth value:{}".format(oid, id)
        cv2.setWindowTitle("img",title)

        cv2.waitKey(1)

print('准确率:{:.6f}'.format(js / (i + 1)))

测试结果准确率达到0.986基本达到要求

image.png

相关文章
|
2月前
|
人工智能 测试技术 API
AI计算机视觉笔记二十 九:yolov10竹签模型,自动数竹签
本文介绍了如何在AutoDL平台上搭建YOLOv10环境并进行竹签检测与计数。首先从官网下载YOLOv10源码并创建虚拟环境,安装依赖库。接着通过官方模型测试环境是否正常工作。然后下载自定义数据集并配置`mycoco128.yaml`文件,使用`yolo detect train`命令或Python代码进行训练。最后,通过命令行或API调用测试训练结果,并展示竹签计数功能。如需转载,请注明原文出处。
|
25天前
|
Python 机器学习/深度学习 人工智能
手把手教你从零开始构建并训练你的第一个强化学习智能体:深入浅出Agent项目实战,带你体验编程与AI结合的乐趣
【10月更文挑战第1天】本文通过构建一个简单的强化学习环境,演示了如何创建和训练智能体以完成特定任务。我们使用Python、OpenAI Gym和PyTorch搭建了一个基础的智能体,使其学会在CartPole-v1环境中保持杆子不倒。文中详细介绍了环境设置、神经网络构建及训练过程。此实战案例有助于理解智能体的工作原理及基本训练方法,为更复杂应用奠定基础。首先需安装必要库: ```bash pip install gym torch ``` 接着定义环境并与之交互,实现智能体的训练。通过多个回合的试错学习,智能体逐步优化其策略。这一过程虽从基础做起,但为后续研究提供了良好起点。
81 4
手把手教你从零开始构建并训练你的第一个强化学习智能体:深入浅出Agent项目实战,带你体验编程与AI结合的乐趣
|
2月前
|
机器学习/深度学习 人工智能 PyTorch
AI计算机视觉笔记三十二:LPRNet车牌识别
LPRNet是一种基于Pytorch的高性能、轻量级车牌识别框架,适用于中国及其他国家的车牌识别。该网络无需对字符进行预分割,采用端到端的轻量化设计,结合了squeezenet和inception的思想。其创新点在于去除了RNN,仅使用CNN与CTC Loss,并通过特定的卷积模块提取上下文信息。环境配置包括使用CPU开发板和Autodl训练环境。训练和测试过程需搭建虚拟环境并安装相关依赖,执行训练和测试脚本时可能遇到若干错误,需相应调整代码以确保正确运行。使用官方模型可获得较高的识别准确率,自行训练时建议增加训练轮数以提升效果。
|
2月前
|
人工智能 开发工具 计算机视觉
AI计算机视觉笔记三十:yolov8_obb旋转框训练
本文介绍了如何使用AUTODL环境搭建YOLOv8-obb的训练流程。首先创建虚拟环境并激活,然后通过指定清华源安装ultralytics库。接着下载YOLOv8源码,并使用指定命令开始训练,过程中可能会下载yolov8n.pt文件。训练完成后,可使用相应命令进行预测测试。
|
2月前
|
人工智能 并行计算 测试技术
AI计算机视觉笔记三十一:基于UNetMultiLane的多车道线等识别
该项目基于开源数据集 VIL100 实现了 UNetMultiLane,用于多车道线及车道线类型的识别。数据集中标注了六个车道的车道线及其类型。项目详细记录了从环境搭建到模型训练与测试的全过程,并提供了在 CPU 上进行训练和 ONNX 转换的代码示例。训练过程约需 4 小时完成 50 个 epoch。此外,还实现了视频检测功能,可在视频中实时识别车道线及其类型。
|
机器学习/深度学习 人工智能 开发工具
打造AI训练基础平台!Unity推出Machine Learning Agents
但在未来,人工智能游戏选手或许将会面临新的对手:另一个人工智能。今天,全球最大的3D游戏引擎Unity宣布发布Unity Machine Learning Agents,通过将其游戏引擎与TensorFlow等机器学习框架相连接
1675 0
|
3天前
|
机器学习/深度学习 人工智能 供应链
AI技术在医疗领域的应用与未来展望###
本文深入探讨了人工智能(AI)技术在医疗领域的多种应用及其带来的革命性变化,从疾病诊断、治疗方案优化到患者管理等方面进行了详细阐述。通过具体案例和数据分析,展示了AI如何提高医疗服务效率、降低成本并改善患者体验。同时,文章也讨论了AI技术在医疗领域面临的挑战和未来发展趋势,为行业从业者和研究人员提供参考。 ###
|
3天前
|
机器学习/深度学习 人工智能 算法
AI技术在医疗领域的应用与挑战
【10月更文挑战第21天】 本文探讨了人工智能(AI)在医疗领域的多种应用,包括疾病诊断、治疗方案推荐、药物研发和患者管理等。通过分析这些应用案例,我们可以看到AI技术如何提高医疗服务的效率和准确性。然而,AI在医疗领域的广泛应用也面临诸多挑战,如数据隐私保护、算法透明度和伦理问题。本文旨在为读者提供一个全面的视角,了解AI技术在医疗领域的潜力和面临的困难。
|
3天前
|
机器学习/深度学习 人工智能 搜索推荐
AI在医疗健康领域的应用与前景
随着科技的不断进步,人工智能(AI)技术已经深入到我们生活的方方面面,特别是在医疗健康领域。本文将探讨AI在医疗健康领域的应用现状、面临的挑战以及未来的发展前景。
|
4天前
|
人工智能 自然语言处理 监控
AI技术在文本情感分析中的应用
【10月更文挑战第22天】本文将探讨人工智能(AI)如何改变我们对文本情感分析的理解和应用。我们将通过实际的代码示例,深入了解AI如何帮助我们识别和理解文本中的情感。无论你是AI新手还是有经验的开发者,这篇文章都将为你提供有价值的信息。让我们一起探索AI的奇妙世界吧!
12 3

热门文章

最新文章