构建你的第一个PyTorch神经网络模型

简介: 【4月更文挑战第17天】本文介绍了如何使用PyTorch构建和训练第一个神经网络模型。首先,准备数据集,如MNIST。接着,自定义神经网络模型`SimpleNet`,包含两个全连接层和ReLU激活函数。然后,定义交叉熵损失函数和SGD优化器。训练模型涉及多次迭代,计算损失、反向传播和参数更新。最后,测试模型性能,计算测试集上的准确率。这是一个基础的深度学习入门示例,为进一步探索复杂项目打下基础。

在深度学习的世界中,PyTorch是一个广受欢迎的开源库,它提供了强大的张量计算功能和灵活的神经网络构建能力。无论你是机器学习的新手还是老手,都可以通过PyTorch快速构建和训练神经网络模型。在这篇文章中,我们将引导你一步步构建你的第一个PyTorch神经网络模型。

一、准备数据集

在开始构建模型之前,你需要准备一个用于训练和测试的数据集。这可以是一个现成的数据集,比如MNIST手写数字数据集,也可以是你自己收集的数据。在本例中,我们将使用PyTorch内置的MNIST数据集。

二、构建神经网络模型

在PyTorch中,你可以通过继承nn.Module类来创建自己的神经网络模型。下面是一个简单的全连接神经网络的例子:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(784, 512)  # 输入层到隐藏层,784是输入特征数,512是隐藏层神经元数
        self.fc2 = nn.Linear(512, 10)   # 隐藏层到输出层,10是输出类别数

    def forward(self, x):
        x = x.view(-1, 784)  # 将输入数据展平为一维向量
        x = F.relu(self.fc1(x))  # 通过第一个全连接层,并使用ReLU激活函数
        x = self.fc2(x)  # 通过第二个全连接层,得到输出
        return x

在上面的代码中,我们定义了一个名为SimpleNet的神经网络类。这个类有两个全连接层(fc1fc2),并使用ReLU激活函数。在forward方法中,我们定义了数据在网络中的前向传播过程。

三、定义损失函数和优化器

接下来,我们需要定义一个损失函数和一个优化器。损失函数用于衡量模型预测结果与实际结果之间的差距,而优化器则用于根据损失函数的梯度更新模型的参数。

criterion = nn.CrossEntropyLoss()  # 定义交叉熵损失函数
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 定义随机梯度下降优化器,学习率为0.01

在上面的代码中,我们使用了交叉熵损失函数(nn.CrossEntropyLoss)和随机梯度下降优化器(torch.optim.SGD)。model.parameters()返回模型中所有可训练的参数,lr=0.01设置了学习率为0.01。

四、训练模型

现在我们可以开始训练模型了。在训练过程中,我们将多次迭代数据集,每次迭代中计算损失、反向传播梯度并更新模型参数。

num_epochs = 5  # 训练轮数
for epoch in range(num_epochs):
    # 前向传播
    outputs = model(inputs)
    loss = criterion(outputs, labels)

    # 反向传播和优化
    optimizer.zero_grad()  # 清零梯度缓存
    loss.backward()  # 反向传播计算梯度
    optimizer.step()  # 更新模型参数

    # 打印统计信息
    if (epoch+1) % 100 == 0:
        print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))

在上面的代码中,我们设置了训练轮数为5。在每个训练轮次中,我们首先进行前向传播,计算模型的输出和损失。然后,我们清零梯度缓存,进行反向传播计算梯度,并使用优化器更新模型参数。最后,我们打印出当前的轮次和损失值,以便监控训练过程。

五、测试模型

训练完成后,我们可以使用测试数据集来评估模型的性能。通常,我们会计算模型在测试集上的准确率等指标。

# 将模型设置为评估模式
model.eval()
with torch.no_grad():  # 不计算梯度,节省计算资源
    correct = 0
    total = 0
    for data in test_loader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)  # 获取预测结果中概率最大的类别
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the test images: {} %'.format(100 * correct / total))

在上面的代码中,我们首先将模型设置为评估模式(model.eval()),然后关闭梯度计算(torch.no_grad()),以避免在测试过程中不必要的计算开销。接着,我们遍历测试数据集,计算模型的预测结果,并统计预测正确的样本数。最后,我们计算并打印出模型在测试集上的准确率。

六、总结

通过以上步骤,你已经成功构建了你的第一个PyTorch神经网络模型,并完成了训练和测试过程。当然,这只是一个简单的示例,实际的深度学习项目可能涉及更复杂的网络结构、损失函数和优化器选择,以及更多的数据预处理和模型调优步骤。希望这篇文章能为你提供一个良好的起点,帮助你进一步探索深度学习的世界。

相关文章
|
2月前
|
机器学习/深度学习 数据采集 人工智能
PyTorch学习实战:AI从数学基础到模型优化全流程精解
本文系统讲解人工智能、机器学习与深度学习的层级关系,涵盖PyTorch环境配置、张量操作、数据预处理、神经网络基础及模型训练全流程,结合数学原理与代码实践,深入浅出地介绍激活函数、反向传播等核心概念,助力快速入门深度学习。
136 1
|
1月前
|
边缘计算 人工智能 PyTorch
130_知识蒸馏技术:温度参数与损失函数设计 - 教师-学生模型的优化策略与PyTorch实现
随着大型语言模型(LLM)的规模不断增长,部署这些模型面临着巨大的计算和资源挑战。以DeepSeek-R1为例,其671B参数的规模即使经过INT4量化后,仍需要至少6张高端GPU才能运行,这对于大多数中小型企业和研究机构来说成本过高。知识蒸馏作为一种有效的模型压缩技术,通过将大型教师模型的知识迁移到小型学生模型中,在显著降低模型复杂度的同时保留核心性能,成为解决这一问题的关键技术之一。
|
2月前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
97 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
1月前
|
机器学习/深度学习 数据采集 人工智能
深度学习实战指南:从神经网络基础到模型优化的完整攻略
🌟 蒋星熠Jaxonic,AI探索者。深耕深度学习,从神经网络到Transformer,用代码践行智能革命。分享实战经验,助你构建CV、NLP模型,共赴二进制星辰大海。
|
2月前
|
机器学习/深度学习 传感器 算法
【无人车路径跟踪】基于神经网络的数据驱动迭代学习控制(ILC)算法,用于具有未知模型和重复任务的非线性单输入单输出(SISO)离散时间系统的无人车的路径跟踪(Matlab代码实现)
【无人车路径跟踪】基于神经网络的数据驱动迭代学习控制(ILC)算法,用于具有未知模型和重复任务的非线性单输入单输出(SISO)离散时间系统的无人车的路径跟踪(Matlab代码实现)
135 2
|
2月前
|
机器学习/深度学习 并行计算 算法
【CPOBP-NSWOA】基于豪冠猪优化BP神经网络模型的多目标鲸鱼寻优算法研究(Matlab代码实现)
【CPOBP-NSWOA】基于豪冠猪优化BP神经网络模型的多目标鲸鱼寻优算法研究(Matlab代码实现)
|
2月前
|
机器学习/深度学习 算法 PyTorch
【Pytorch框架搭建神经网络】基于DQN算法、优先级采样的DQN算法、DQN + 人工势场的避障控制研究(Python代码实现)
【Pytorch框架搭建神经网络】基于DQN算法、优先级采样的DQN算法、DQN + 人工势场的避障控制研究(Python代码实现)
|
2月前
|
机器学习/深度学习 算法 PyTorch
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
103 0
|
3月前
|
算法 安全 网络安全
【多智能体系统】遭受DoS攻击的网络物理多智能体系统的弹性模型预测控制MPC研究(Simulink仿真实现)
【多智能体系统】遭受DoS攻击的网络物理多智能体系统的弹性模型预测控制MPC研究(Simulink仿真实现)
145 0
|
11月前
|
SQL 安全 网络安全
网络安全与信息安全:知识分享####
【10月更文挑战第21天】 随着数字化时代的快速发展,网络安全和信息安全已成为个人和企业不可忽视的关键问题。本文将探讨网络安全漏洞、加密技术以及安全意识的重要性,并提供一些实用的建议,帮助读者提高自身的网络安全防护能力。 ####
245 17

热门文章

最新文章

推荐镜像

更多