怎么调用pytorch中mnist数据集

简介: 怎么调用pytorch中mnist数据集

问题

怎么调用pytorch中mnist数据集


方法

MNIST数据集介绍

MNIST数据集是NIST(National Institute of Standards and Technology,美国国家标准与技术研究所)数据集的一个子集,MNIST 数据集主要包括四个文件,训练集train一共包含了 60000 张图像和标签,而测试集一共包含了 10000 张图像和标签。

idx3表示3维,ubyte表示是以字节的形式进行存储的,t10k表示10000张测试图片(test10000),每张图片是一个28*28像素点的0 ~ 9的灰质手写数字图片,黑底白字,图像像素值为0 ~ 255,越大该点越白。


数据下载和读取

导入PyTorch的两个核心库torch和torchvision,这两个库基本包含了PyTorch会用到的许多方法和函数,其他库为下面所需要的一些辅助库。

import gzip

import os

import torch

import torchvision

import numpy as np

from PIL import Image

from matplotlib import pyplot as plt

from torchvision import datasets, transforms

from torch.utils.data import DataLoader, Dataset

import datasets是为了方便自动下载数据集,可以下载多种数据集,如MNIST、ImageNet、CIFAR10等。

import transforms是pytorch中的图像预处理库,一般用Compose把多个步骤整合到一起。相关详情见:transforms.Compose()函数

使用Pytorch自带的库函数

导入MNIST数据集代码:

train_data = datasets.MNIST(

          root="./data/",

          train=True,

          transform=transforms.To

通过重构Dataset类读取特定的MNIST数据或者制作自己的MNIST数据集

① 读取MNIST文件夹下processed文件中的training.pt、test.pt数据集

class Data_Loader(Dataset):

  def __init__(self, root, transform=None):

      self.data, self.targets = torch.load(root)#采用torch.load进行读取,读取之后的结果为torch.Tensor形式

self.transform = transform

  def __getitem__(self, index):

      img, target = self.data[index], int(self.targets[index])

      img = Image.fromarray(img.numpy(), mode='L')

      if self.transform is not None:

          img = self.transform(img)

      img = transforms.ToTensor()(img)

      return img, target

  def __len__(self):

      return len(self.data)

接下来,调用我们自定义的Data_Loader类来读取数据集:

# root 为training.pt、test.pt文件所在的绝对路径

train_data = Data_Loader(root='./mnist/MNIST/processed/training.pt', transform= None)

test_data = Data_Loader(root='./mnist/MNIST/processed/test.pt', transform= None)

再使用torch.utils.data.DataLoader对train_data和test_data进行加载,展示。

② 读取MNIST文件夹下raw文件中的数据集

class Data_Loader(Dataset):

  def __init__(self, folder, data_name, label_name, transform=None):

      (train_set, train_labels) = load_data(folder, data_name, label_name)

      self.train_set = train_set

      self.train_labels = train_labels

      self.transform = transform

  def __getitem__(self, index):

      img, target = self.train_set[index], int(self.train_labels[index])

      if self.transform is not None:

          img = self.transform(img)

      return img, target

  def __len__(self):

      return len(self.train_set)

def load_data(data_folder, data_name, label_name):

  with gzip.open(os.path.join(data_folder, label_name), 'rb') as lbpath:  # rb表示的是读取二进制数据

      y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

  with gzip.open(os.path.join(data_folder, data_name), 'rb') as imgpath:

      x_train = np.frombuffer(

          imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)

  return x_train, y_train

接下来,调用我们自定义的Data_Loader类来读取数据集:

#folder:MNIST数据集中raw文件的绝对路径

# 读取MNIST数据集中的训练集

train_data = Data_Loader('./MNIST/MNIST/raw', "train-images-idx3-ubyte.gz",

                         "train-labels-idx1-ubyte.gz", transform=transforms.ToTensor())

# 读取MNIST数据集中的测试集

test_data = Data_Loader('./MNIST/MNIST/raw', "t10k-images-idx3-ubyte.gz",

                         "t10k-labels-idx1-ubyte.gz", transform=transforms.ToTensor())

再使用torch.utils.data.DataLoader对train_data和test_data进行加载,展示。

③ 直接读取MNIST数据集


总结

mnist数据集是一个计算机视觉数据集,训练集包括六万张图片,测试集一万张图片,并且已经进行过预处理和格式化。这些数据集有两个功能:一个功能是提供了大量的数据作为训练集和验证集,为一些学习人员提供了丰富的样 本信息一一这一点很宝贵,要知道在深度学习领域要想在一个方面有比较深的研究成果, 除了需要具备一定的网络设计和调优能力以外,还有一个就是要有丰富的训练样本。另一 个功能就是可以形成一个在业内相对有普适性的 Benchmark 比对项目一一既然大家用的数 据集都是一样的,那么每个人设计出来的网络就可以在这些数据集上不断互相比较,从而 验证谁家的网络设计得识别率更高。

目录
打赏
0
0
0
0
14
分享
相关文章
【PyTorch实战演练】AlexNet网络模型构建并使用Cifar10数据集进行批量训练(附代码)
【PyTorch实战演练】AlexNet网络模型构建并使用Cifar10数据集进行批量训练(附代码)
601 0
基于CUDA12.1+CUDNN8.9+PYTORCH2.3.1,实现自定义数据集训练
文章介绍了如何在CUDA 12.1、CUDNN 8.9和PyTorch 2.3.1环境下实现自定义数据集的训练,包括环境配置、预览结果和核心步骤,以及遇到问题的解决方法和参考链接。
236 4
基于CUDA12.1+CUDNN8.9+PYTORCH2.3.1,实现自定义数据集训练
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
【从零开始学习深度学习】38. Pytorch实战案例:梯度下降、随机梯度下降、小批量随机梯度下降3种优化算法对比【含数据集与源码】
【从零开始学习深度学习】34. Pytorch-RNN项目实战:RNN创作歌词案例--使用周杰伦专辑歌词训练模型并创作歌曲【含数据集与源码】
【从零开始学习深度学习】34. Pytorch-RNN项目实战:RNN创作歌词案例--使用周杰伦专辑歌词训练模型并创作歌曲【含数据集与源码】
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
PyTorch分布式训练:加速大规模数据集的处理
【4月更文挑战第18天】PyTorch分布式训练加速大规模数据集处理,通过数据并行和模型并行提升训练效率。`torch.distributed`提供底层IPC与同步,适合定制化需求;`DistributedDataParallel`则简化并行过程。实际应用注意数据划分、通信开销、负载均衡及错误处理。借助PyTorch分布式工具,可高效应对深度学习的计算挑战,未来潜力无限。
基于昇腾用PyTorch实现传统CTR模型WideDeep网络
本文介绍了如何在昇腾平台上使用PyTorch实现经典的WideDeep网络模型,以处理推荐系统中的点击率(CTR)预测问题。
196 66
AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等