LeNet网络搭建与基本训练流程

简介: LeNet网络搭建与基本训练流程

模型


2de6b682a217a754ea7bbef5c366d493.png


class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()   # 解决继承父类中出现的一系列问题
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
    def forward(self, x):
        x = F.relu(self.conv1(x))       # 输入(3,32,32) 输出(16,28,28)
        x = self.pool1(x)               # 输出(16,14,14)
        x = F.relu(self.conv2(x))       # 输出(32,10,10)
        x = self.pool2(x)               # 输出(32,5,5)
        x = x.view(-1, 32*5*5)          # 输出(32*5*5),batch=-1设为动态调整这个维度上的元素的个数,以保证元素的总数不变
        x = F.relu(self.fc1(x))         # 输出(120)
        x = F.relu(self.fc2(x))         # 输出(84)
        x = self.fc3(x)                 # 输出(10)
        return x


预处理


transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


1.转换图片数据为pytorch中的Tensor格式

2.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))将数据转换为标准正态分布,即逐个c h a n n e l channelchannel的对图像进行标准化(均值变为0 00,标准差为1 11),可以加快模型的收敛


加载数据集


# 50000张训练图片
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=36, shuffle=True, num_workers=0)
# 10000张测试图片
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=10000, shuffle=False, num_workers=0)
test_data_iter = iter(testloader)
test_image, test_label = test_data_iter.next()
classes = ('plane', 'car', 'bird', 'cat', 'deer'
           'dog', 'frog', 'horse', 'ship', 'truck')


加载cifar10数据集,然后取出测试集的图片和标签


3c921f8fb15391e781d42960bc342f4b.png


image-20220707211309072.png


训练


1.加载模型、定义损失函数、优化器


net = LeNet()
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)


  • nn.CrossEntropyLoss():使用交叉熵损失函数
  • optim.Adam(net.parameters(), lr=0.001):net.parameters()是将net的参数都丢进优化器里
  • net = LeNet():注意LeNet后面有一个()


5cb30f14528e0673052d788ad29f085f.gif


2.训练循环


def train_process():
    for epoch in range(10):
        running_loss = 0.0
        for step, data in enumerate(trainloader, start=0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if step % 500 ==499:
                with torch.no_grad():
                    outputs = net(test_image)
                    predict_y = torch.max(outputs, dim=1)[1]
                    accuracy = (predict_y == test_label).sum().item() / test_label.size(0)
                    print('[%d, %5d] train_loss: %.3f test_accuracy:%.3f'
                          %(epoch+1, step+1, running_loss/500, accuracy))
                    running_loss = 0.0
    print("Finished Training")
train_process()
save_path = 'Lenet.pth'
torch.save(net.state_dict(),save_path)


  • for step, data in enumerate(trainloader, start=0):


enumerate()函数

  • 示例:for step, data in enumerate(trainloader, start=0):


  • 作用:将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标


  • 参数:


1.sequence:一个序列、数组或其它对象

2.start:下标起始位置的值


  • 实例


e72a6b4a14c2da46d8607512835e4e55.png


  • predict_y = torch.max(outputs, dim=1)[1]:dim=1选择行中最大的概率值,dim=0选择列最大的概率值。[1]表示切片,因为torch.max会返回两个数值,第一个是这个概率值,第二个是🌈序号。正常可以这么写: _, predicted = torch.max(outputs.data,dim=1)


  • accuracy = (predict_y == test_label).sum().item() / test_label.size(0):如果正确的预测累加,通过item转换为数值,除以总的测试长度,得到正确的结果


  • torch.save(net.state_dict(),save_path):保存权重文件(.pth)


测试


import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNet
transform = transforms.Compose([transforms.Resize((32,32)),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
classes = ('plane', 'car', 'bird', 'cat', 'deer'
           'dog', 'frog', 'horse', 'ship', 'truck')
net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))
img = Image.open('data/plane.png')
img = transform(img)
img = torch.unsqueeze(img, dim=0)
with torch.no_grad():
    outputs = net(img)
    # predict = torch.max(outputs,dim=1)[1].data.numpy()
    predict = torch.softmax(outputs, dim=1)
print(predict)
# print(classes[int(predict)])


  • net.load_state_dict(torch.load('Lenet.pth')):加载权重
  • with torch.no_grad()::不反向传播计算梯度,减少计算量
  • print(classes[int(predict)]):与class联用,通过索引直接输出标签
相关文章
|
14天前
|
机器学习/深度学习
神经网络与深度学习---验证集(测试集)准确率高于训练集准确率的原因
本文分析了神经网络中验证集(测试集)准确率高于训练集准确率的四个可能原因,包括数据集大小和分布不均、模型正则化过度、批处理后准确率计算时机不同,以及训练集预处理过度导致分布变化。
|
2天前
|
机器学习/深度学习 数据采集 数据可视化
深度学习实践:构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行分类
本文详细介绍如何使用PyTorch构建并训练卷积神经网络(CNN)对CIFAR-10数据集进行图像分类。从数据预处理、模型定义到训练过程及结果可视化,文章全面展示了深度学习项目的全流程。通过实际操作,读者可以深入了解CNN在图像分类任务中的应用,并掌握PyTorch的基本使用方法。希望本文为您的深度学习项目提供有价值的参考与启示。
|
8天前
|
网络协议 C语言
C语言 网络编程(十一)TCP通信创建流程---服务端
在服务器流程中,新增了绑定IP地址与端口号、建立监听队列及接受连接并创建新文件描述符等步骤。`bind`函数用于绑定IP地址与端口,`listen`函数建立监听队列并设置监听状态,`accept`函数则接受连接请求并创建新的文件描述符用于数据传输。套接字状态包括关闭(CLOSED)、同步发送(SYN-SENT)、同步接收(SYN-RECEIVE)和已建立连接(ESTABLISHED)。示例代码展示了TCP服务端程序如何初始化socket、绑定地址、监听连接请求以及接收和发送数据。
|
8天前
|
C语言
C语言 网络编程(七)UDP通信创建流程
本文档详细介绍了使用 UDP 协议进行通信的过程,包括创建套接字、发送与接收消息等关键步骤。首先,通过 `socket()` 函数创建套接字,并设置相应的参数。接着,使用 `sendto()` 函数向指定地址发送数据。为了绑定地址,需要调用 `bind()` 函数。接收端则通过 `recvfrom()` 函数接收数据并获取发送方的地址信息。文档还提供了完整的代码示例,展示了如何实现 UDP 的发送端和服务端功能。
|
8天前
|
网络协议 C语言
C语言 网络编程(十)TCP通信创建流程---客户端
在TCP通信中,客户端需通过一系列步骤与服务器建立连接并进行数据传输。首先使用 `socket()` 函数创建一个流式套接字,然后通过 `connect()` 函数连接服务器。连接成功后,可以使用 `send()` 和 `recv()` 函数进行数据发送和接收。最后展示了一个完整的客户端示例代码,实现了与服务器的通信过程。
|
1月前
|
机器学习/深度学习
CNN网络编译和训练
【8月更文挑战第10天】CNN网络编译和训练。
62 20
|
11天前
|
测试技术 持续交付 开发者
Xamarin 高效移动应用测试最佳实践大揭秘,从框架选择到持续集成,让你的应用质量无敌!
【8月更文挑战第31天】竞争激烈的移动应用市场,Xamarin 作为一款优秀的跨平台开发工具,提供了包括单元测试、集成测试及 UI 测试在内的全面测试方案。借助 Xamarin.UITest 框架,开发者能便捷地用 C# 编写测试案例,如登录功能测试;通过 Xamarin 模拟框架,则可在无需真实设备的情况下模拟各种环境测试应用表现;Xamarin.TestCloud 则支持在真实设备上执行自动化测试,确保应用兼容性。结合持续集成与部署策略,进一步提升测试效率与应用质量。掌握 Xamarin 的测试最佳实践,对确保应用稳定性和优化用户体验至关重要。
24 0
|
14天前
|
机器学习/深度学习 PyTorch 测试技术
深度学习入门:使用 PyTorch 构建和训练你的第一个神经网络
【8月更文第29天】深度学习是机器学习的一个分支,它利用多层非线性处理单元(即神经网络)来解决复杂的模式识别问题。PyTorch 是一个强大的深度学习框架,它提供了灵活的 API 和动态计算图,非常适合初学者和研究者使用。
25 0
|
1月前
|
机器学习/深度学习 API 算法框架/工具
【Tensorflow+keras】Keras API两种训练GAN网络的方式
使用Keras API以两种不同方式训练条件生成对抗网络(CGAN)的示例代码:一种是使用train_on_batch方法,另一种是使用tf.GradientTape进行自定义训练循环。
24 5
|
14天前
|
机器学习/深度学习 自然语言处理 TensorFlow
深度学习的奥秘:探索神经网络的构建与训练
【8月更文挑战第28天】本文旨在揭开深度学习的神秘面纱,通过浅显易懂的语言和直观的代码示例,引导读者理解并实践神经网络的构建与训练。我们将从基础概念出发,逐步深入到模型的实际应用,让初学者也能轻松掌握深度学习的核心技能。