LeNet网络搭建与基本训练流程

简介: LeNet网络搭建与基本训练流程

模型


2de6b682a217a754ea7bbef5c366d493.png


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()   # 解决继承父类中出现的一系列问题
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x):
        x = F.relu(self.conv1(x))       # 输入(3,32,32) 输出(16,28,28)
        x = self.pool1(x)               # 输出(16,14,14)
        x = F.relu(self.conv2(x))       # 输出(32,10,10)
        x = self.pool2(x)               # 输出(32,5,5)
        x = x.view(-1, 32*5*5)          # 输出(32*5*5),batch=-1设为动态调整这个维度上的元素的个数,以保证元素的总数不变
        x = F.relu(self.fc1(x))         # 输出(120)
        x = F.relu(self.fc2(x))         # 输出(84)
        x = self.fc3(x)                 # 输出(10)
        return x


预处理


transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


1.转换图片数据为pytorch中的Tensor格式

2.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))将数据转换为标准正态分布,即逐个c h a n n e l channelchannel的对图像进行标准化(均值变为0 00,标准差为1 11),可以加快模型的收敛


加载数据集


# 50000张训练图片
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=36, shuffle=True, num_workers=0)
# 10000张测试图片
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=10000, shuffle=False, num_workers=0)
test_data_iter = iter(testloader)
test_image, test_label = test_data_iter.next()
classes = ('plane', 'car', 'bird', 'cat', 'deer'
           'dog', 'frog', 'horse', 'ship', 'truck')


加载cifar10数据集,然后取出测试集的图片和标签


3c921f8fb15391e781d42960bc342f4b.png


image-20220707211309072.png


训练


1.加载模型、定义损失函数、优化器


net = LeNet()
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)


  • nn.CrossEntropyLoss():使用交叉熵损失函数
  • optim.Adam(net.parameters(), lr=0.001):net.parameters()是将net的参数都丢进优化器里
  • net = LeNet():注意LeNet后面有一个()


5cb30f14528e0673052d788ad29f085f.gif


2.训练循环


def train_process():
    for epoch in range(10):
        running_loss = 0.0
        for step, data in enumerate(trainloader, start=0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if step % 500 ==499:
                with torch.no_grad():
                    outputs = net(test_image)
                    predict_y = torch.max(outputs, dim=1)[1]
                    accuracy = (predict_y == test_label).sum().item() / test_label.size(0)
                    print('[%d, %5d] train_loss: %.3f test_accuracy:%.3f'
                          %(epoch+1, step+1, running_loss/500, accuracy))
                    running_loss = 0.0
    print("Finished Training")
train_process()
save_path = 'Lenet.pth'
torch.save(net.state_dict(),save_path)


  • for step, data in enumerate(trainloader, start=0):


enumerate()函数

  • 示例:for step, data in enumerate(trainloader, start=0):


  • 作用:将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标


  • 参数:


1.sequence:一个序列、数组或其它对象

2.start:下标起始位置的值


  • 实例


e72a6b4a14c2da46d8607512835e4e55.png


  • predict_y = torch.max(outputs, dim=1)[1]:dim=1选择行中最大的概率值,dim=0选择列最大的概率值。[1]表示切片,因为torch.max会返回两个数值,第一个是这个概率值,第二个是🌈序号。正常可以这么写: _, predicted = torch.max(outputs.data,dim=1)


  • accuracy = (predict_y == test_label).sum().item() / test_label.size(0):如果正确的预测累加,通过item转换为数值,除以总的测试长度,得到正确的结果


  • torch.save(net.state_dict(),save_path):保存权重文件(.pth)


测试


import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet
transform = transforms.Compose([transforms.Resize((32,32)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
classes = ('plane', 'car', 'bird', 'cat', 'deer'
           'dog', 'frog', 'horse', 'ship', 'truck')
net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))
img = Image.open('data/plane.png')
img = transform(img)
img = torch.unsqueeze(img, dim=0)
with torch.no_grad():
    outputs = net(img)
    # predict = torch.max(outputs,dim=1)[1].data.numpy()
    predict = torch.softmax(outputs, dim=1)
print(predict)
# print(classes[int(predict)])


  • net.load_state_dict(torch.load('Lenet.pth')):加载权重
  • with torch.no_grad()::不反向传播计算梯度,减少计算量
  • print(classes[int(predict)]):与class联用,通过索引直接输出标签
相关文章
|
3月前
|
机器学习/深度学习 人工智能 算法
AI 基础知识从 0.6 到 0.7—— 彻底拆解深度神经网络训练的五大核心步骤
本文以一个经典的PyTorch手写数字识别代码示例为引子,深入剖析了简洁代码背后隐藏的深度神经网络(DNN)训练全过程。
831 56
|
1月前
|
机器学习/深度学习 数据可视化 网络架构
PINN训练新思路:把初始条件和边界约束嵌入网络架构,解决多目标优化难题
PINNs训练难因多目标优化易失衡。通过设计硬约束网络架构,将初始与边界条件内嵌于模型输出,可自动满足约束,仅需优化方程残差,简化训练过程,提升稳定性与精度,适用于气候、生物医学等高要求仿真场景。
289 4
PINN训练新思路:把初始条件和边界约束嵌入网络架构,解决多目标优化难题
|
7月前
|
机器学习/深度学习 存储 算法
NoProp:无需反向传播,基于去噪原理的非全局梯度传播神经网络训练,可大幅降低内存消耗
反向传播算法虽是深度学习基石,但面临内存消耗大和并行扩展受限的问题。近期,牛津大学等机构提出NoProp方法,通过扩散模型概念,将训练重塑为分层去噪任务,无需全局前向或反向传播。NoProp包含三种变体(DT、CT、FM),具备低内存占用与高效训练优势,在CIFAR-10等数据集上达到与传统方法相当的性能。其层间解耦特性支持分布式并行训练,为无梯度深度学习提供了新方向。
298 1
NoProp:无需反向传播,基于去噪原理的非全局梯度传播神经网络训练,可大幅降低内存消耗
|
9月前
|
机器学习/深度学习 文件存储 异构计算
YOLOv11改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
YOLOv11改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
1044 18
YOLOv11改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
|
9月前
|
机器学习/深度学习 数据可视化 API
DeepSeek生成对抗网络(GAN)的训练与应用
生成对抗网络(GANs)是深度学习的重要技术,能生成逼真的图像、音频和文本数据。通过生成器和判别器的对抗训练,GANs实现高质量数据生成。DeepSeek提供强大工具和API,简化GAN的训练与应用。本文介绍如何使用DeepSeek构建、训练GAN,并通过代码示例帮助掌握相关技巧,涵盖模型定义、训练过程及图像生成等环节。
|
9月前
|
机器学习/深度学习 文件存储 异构计算
RT-DETR改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
RT-DETR改进策略【模型轻量化】| 替换骨干网络为EfficientNet v2,加速训练,快速收敛
223 1
|
11月前
|
人工智能 搜索推荐 决策智能
不靠更复杂的策略,仅凭和大模型训练对齐,零样本零经验单LLM调用,成为网络任务智能体新SOTA
近期研究通过调整网络智能体的观察和动作空间,使其与大型语言模型(LLM)的能力对齐,显著提升了基于LLM的网络智能体性能。AgentOccam智能体在WebArena基准上超越了先前方法,成功率提升26.6个点(+161%)。该研究强调了与LLM训练目标一致的重要性,为网络任务自动化提供了新思路,但也指出其性能受限于LLM能力及任务复杂度。论文链接:https://arxiv.org/abs/2410.13825。
216 12
|
11月前
|
SQL 安全 网络安全
网络安全与信息安全:知识分享####
【10月更文挑战第21天】 随着数字化时代的快速发展,网络安全和信息安全已成为个人和企业不可忽视的关键问题。本文将探讨网络安全漏洞、加密技术以及安全意识的重要性,并提供一些实用的建议,帮助读者提高自身的网络安全防护能力。 ####
264 17
|
11月前
|
SQL 安全 网络安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将从网络安全漏洞、加密技术和安全意识三个方面进行探讨,旨在提高读者对网络安全的认识和防范能力。通过分析常见的网络安全漏洞,介绍加密技术的基本原理和应用,以及强调安全意识的重要性,帮助读者更好地保护自己的网络信息安全。
223 10
|
11月前
|
存储 SQL 安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
随着互联网的普及,网络安全问题日益突出。本文将介绍网络安全的重要性,分析常见的网络安全漏洞及其危害,探讨加密技术在保障网络安全中的作用,并强调提高安全意识的必要性。通过本文的学习,读者将了解网络安全的基本概念和应对策略,提升个人和组织的网络安全防护能力。