PyTorch框架和MNIST数据集

简介: 6月更文挑战20天
  1. PyTorch框架:
    PyTorch是一个开源的机器学习库,由Facebook的人工智能研究团队开发,适用于Python程序,支持强大的GPU加速。PyTorch提供了两个主要功能:一是张量计算(如NumPy),二是基于这些张量的自动求导机制。这使得PyTorch非常适合进行深度学习研究和应用。PyTorch的动态计算图(也称为即时执行)是其与其他深度学习框架(如TensorFlow)的主要区别,这使得它在研究和调试方面更为灵活。
  2. MNIST数据集:
    MNIST(Modified National Institute of Standards and Technology database)是一个广泛使用的手写数字识别数据集。它包含60,000个训练样本和10,000个测试样本。每个样本都是一张28x28像素的灰度图像,代表一个0到9之间的数字。MNIST数据集通常被用作入门级的机器学习和深度学习项目,因为它既不太简单也不太复杂,非常适合初步学习和实践。
    下面是使用PyTorch框架来训练一个简单的神经网络进行MNIST手写数字识别的代码示例:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms
    # 定义超参数
    BATCH_SIZE = 64
    EPOCHS = 5
    LEARNING_RATE = 0.01
    # 数据预处理:将数据转换为torch.FloatTensor,并标准化
    transform = transforms.Compose([
     transforms.ToTensor(),
     transforms.Normalize((0.1307,), (0.3081,))
    ])
    # 下载并加载训练数据
    train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    # 下载并加载测试数据
    test_data = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)
    # 定义一个简单的神经网络模型
    class Net(nn.Module):
     def __init__(self):
         super(Net, self).__init__()
         self.fc1 = nn.Linear(28*28, 500)
         self.fc2 = nn.Linear(500, 10)
     def forward(self, x):
         x = x.view(-1, 28*28)
         x = F.relu(self.fc1(x))
         x = self.fc2(x)
         return x
    # 实例化模型
    model = Net()
    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
    # 训练模型
    for epoch in range(EPOCHS):
     for images, labels in train_loader:
         # 前向传播
         outputs = model(images)
         loss = criterion(outputs, labels)
    
         # 反向传播和优化
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
    
     print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {loss.item()}')
    # 测试模型
    model.eval()  # 将模型设置为评估模式
    with torch.no_grad():  # 在这个with下,所有计算得出的tensor都不会计算梯度,也就是不会进行反向传播
     correct = 0
     total = 0
     for images, labels in test_loader:
         outputs = model(images)
         _, predicted = torch.max(outputs.data, 1)
         total += labels.size(0)
         correct += (predicted == labels).sum().item()
     print(f'Accuracy of the network on the 10000 test images: {100 * correct / total}%')
    
相关文章
|
2月前
|
存储 人工智能 PyTorch
基于PyTorch/XLA的高效分布式训练框架
基于PyTorch/XLA的高效分布式训练框架
225 2
|
1月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
|
1月前
|
机器学习/深度学习 自然语言处理 PyTorch
【从零开始学习深度学习】34. Pytorch-RNN项目实战:RNN创作歌词案例--使用周杰伦专辑歌词训练模型并创作歌曲【含数据集与源码】
【从零开始学习深度学习】34. Pytorch-RNN项目实战:RNN创作歌词案例--使用周杰伦专辑歌词训练模型并创作歌曲【含数据集与源码】
|
1月前
|
机器学习/深度学习 资源调度 PyTorch
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
|
1月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
|
2月前
|
机器学习/深度学习 数据可视化 PyTorch
TensorFlow与PyTorch框架的深入对比:特性、优势与应用场景
【5月更文挑战第4天】本文对比了深度学习主流框架TensorFlow和PyTorch的特性、优势及应用场景。TensorFlow以其静态计算图、高性能及TensorBoard可视化工具适合大规模数据处理和复杂模型,但学习曲线较陡峭。PyTorch则以动态计算图、易用性和灵活性见长,便于研究和原型开发,但在性能和部署上有局限。选择框架应根据具体需求和场景。
|
2月前
|
PyTorch 算法框架/工具 Python
【pytorch框架】对模型知识的基本了解
【pytorch框架】对模型知识的基本了解
|
2月前
|
机器学习/深度学习 数据可视化 PyTorch
利用PyTorch实现基于MNIST数据集的手写数字识别
利用PyTorch实现基于MNIST数据集的手写数字识别
69 2
|
2月前
|
机器学习/深度学习 负载均衡 PyTorch
PyTorch分布式训练:加速大规模数据集的处理
【4月更文挑战第18天】PyTorch分布式训练加速大规模数据集处理,通过数据并行和模型并行提升训练效率。`torch.distributed`提供底层IPC与同步,适合定制化需求;`DistributedDataParallel`则简化并行过程。实际应用注意数据划分、通信开销、负载均衡及错误处理。借助PyTorch分布式工具,可高效应对深度学习的计算挑战,未来潜力无限。
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】28.卷积神经网络之NiN模型介绍及其Pytorch实现【含完整代码】
【从零开始学习深度学习】28.卷积神经网络之NiN模型介绍及其Pytorch实现【含完整代码】