PyTorch框架和MNIST数据集

简介: 6月更文挑战20天
  1. PyTorch框架:
    PyTorch是一个开源的机器学习库,由Facebook的人工智能研究团队开发,适用于Python程序,支持强大的GPU加速。PyTorch提供了两个主要功能:一是张量计算(如NumPy),二是基于这些张量的自动求导机制。这使得PyTorch非常适合进行深度学习研究和应用。PyTorch的动态计算图(也称为即时执行)是其与其他深度学习框架(如TensorFlow)的主要区别,这使得它在研究和调试方面更为灵活。
  2. MNIST数据集:
    MNIST(Modified National Institute of Standards and Technology database)是一个广泛使用的手写数字识别数据集。它包含60,000个训练样本和10,000个测试样本。每个样本都是一张28x28像素的灰度图像,代表一个0到9之间的数字。MNIST数据集通常被用作入门级的机器学习和深度学习项目,因为它既不太简单也不太复杂,非常适合初步学习和实践。
    下面是使用PyTorch框架来训练一个简单的神经网络进行MNIST手写数字识别的代码示例:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms
    # 定义超参数
    BATCH_SIZE = 64
    EPOCHS = 5
    LEARNING_RATE = 0.01
    # 数据预处理:将数据转换为torch.FloatTensor,并标准化
    transform = transforms.Compose([
     transforms.ToTensor(),
     transforms.Normalize((0.1307,), (0.3081,))
    ])
    # 下载并加载训练数据
    train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    # 下载并加载测试数据
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)
    # 定义一个简单的神经网络模型
    class Net(nn.Module):
     def __init__(self):
         super(Net, self).__init__()
         self.fc1 = nn.Linear(28*28, 500)
         self.fc2 = nn.Linear(500, 10)
     def forward(self, x):
         x = x.view(-1, 28*28)
         x = F.relu(self.fc1(x))
         x = self.fc2(x)
         return x
    # 实例化模型
    model = Net()
    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
    # 训练模型
    for epoch in range(EPOCHS):
     for images, labels in train_loader:
         # 前向传播
         outputs = model(images)
         loss = criterion(outputs, labels)
    
         # 反向传播和优化
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
    
     print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item()}')
    # 测试模型
    model.eval()  # 将模型设置为评估模式
    with torch.no_grad():  # 在这个with下,所有计算得出的tensor都不会计算梯度,也就是不会进行反向传播
     correct = 0
     total = 0
     for images, labels in test_loader:
         outputs = model(images)
         _, predicted = torch.max(outputs.data, 1)
         total += labels.size(0)
         correct += (predicted == labels).sum().item()
     print(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')
    
相关文章
|
7天前
|
机器学习/深度学习 监控 PyTorch
深度学习工程实践:PyTorch Lightning与Ignite框架的技术特性对比分析
在深度学习框架的选择上,PyTorch Lightning和Ignite代表了两种不同的技术路线。本文将从技术实现的角度,深入分析这两个框架在实际应用中的差异,为开发者提供客观的技术参考。
25 7
|
2月前
|
并行计算 PyTorch 算法框架/工具
基于CUDA12.1+CUDNN8.9+PYTORCH2.3.1,实现自定义数据集训练
文章介绍了如何在CUDA 12.1、CUDNN 8.9和PyTorch 2.3.1环境下实现自定义数据集的训练,包括环境配置、预览结果和核心步骤,以及遇到问题的解决方法和参考链接。
101 4
基于CUDA12.1+CUDNN8.9+PYTORCH2.3.1,实现自定义数据集训练
|
5月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
|
3月前
|
机器学习/深度学习 人工智能 PyTorch
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
76 1
|
3月前
|
UED 开发者
哇塞!Uno Platform 数据绑定超全技巧大揭秘!从基础绑定到高级转换,优化性能让你的开发如虎添翼
【8月更文挑战第31天】在开发过程中,数据绑定是连接数据模型与用户界面的关键环节,可实现数据自动更新。Uno Platform 提供了简洁高效的数据绑定方式,使属性变化时 UI 自动同步更新。通过示例展示了基本绑定方法及使用 `Converter` 转换数据的高级技巧,如将年龄转换为格式化字符串。此外,还可利用 `BindingMode.OneTime` 提升性能。掌握这些技巧能显著提高开发效率并优化用户体验。
60 0
|
3月前
|
机器学习/深度学习 PyTorch TensorFlow
深度学习框架之争:全面解析TensorFlow与PyTorch在功能、易用性和适用场景上的比较,帮助你选择最适合项目的框架
【8月更文挑战第31天】在深度学习领域,选择合适的框架至关重要。本文通过开发图像识别系统的案例,对比了TensorFlow和PyTorch两大主流框架。TensorFlow由Google开发,功能强大,支持多种设备,适合大型项目和工业部署;PyTorch则由Facebook推出,强调灵活性和速度,尤其适用于研究和快速原型开发。通过具体示例代码展示各自特点,并分析其适用场景,帮助读者根据项目需求和个人偏好做出明智选择。
66 0
|
5月前
|
机器学习/深度学习 自然语言处理 PyTorch
【从零开始学习深度学习】34. Pytorch-RNN项目实战:RNN创作歌词案例--使用周杰伦专辑歌词训练模型并创作歌曲【含数据集与源码】
【从零开始学习深度学习】34. Pytorch-RNN项目实战:RNN创作歌词案例--使用周杰伦专辑歌词训练模型并创作歌曲【含数据集与源码】
|
5月前
|
机器学习/深度学习 资源调度 PyTorch
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
|
5月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
|
6月前
|
机器学习/深度学习 数据可视化 PyTorch
TensorFlow与PyTorch框架的深入对比:特性、优势与应用场景
【5月更文挑战第4天】本文对比了深度学习主流框架TensorFlow和PyTorch的特性、优势及应用场景。TensorFlow以其静态计算图、高性能及TensorBoard可视化工具适合大规模数据处理和复杂模型,但学习曲线较陡峭。PyTorch则以动态计算图、易用性和灵活性见长,便于研究和原型开发,但在性能和部署上有局限。选择框架应根据具体需求和场景。