怎么调用pytorch中mnist数据集

简介: 怎么调用pytorch中mnist数据集

问题

怎么调用pytorch中mnist数据集


方法

MNIST数据集介绍

MNIST数据集是NIST(National Institute of Standards and Technology,美国国家标准与技术研究所)数据集的一个子集,MNIST 数据集主要包括四个文件,训练集train一共包含了 60000 张图像和标签,而测试集一共包含了 10000 张图像和标签。

idx3表示3维,ubyte表示是以字节的形式进行存储的,t10k表示10000张测试图片(test10000),每张图片是一个28*28像素点的0 ~ 9的灰质手写数字图片,黑底白字,图像像素值为0 ~ 255,越大该点越白。


数据下载和读取

导入PyTorch的两个核心库torch和torchvision,这两个库基本包含了PyTorch会用到的许多方法和函数,其他库为下面所需要的一些辅助库。

import gzip

import os

import torch

import torchvision

import numpy as np

from PIL import Image

from matplotlib import pyplot as plt

from torchvision import datasets, transforms

from torch.utils.data import DataLoader, Dataset

import datasets是为了方便自动下载数据集,可以下载多种数据集,如MNIST、ImageNet、CIFAR10等。

import transforms是pytorch中的图像预处理库,一般用Compose把多个步骤整合到一起。相关详情见:transforms.Compose()函数

使用Pytorch自带的库函数

导入MNIST数据集代码:

train_data = datasets.MNIST(

          root="./data/",

          train=True,

          transform=transforms.To

通过重构Dataset类读取特定的MNIST数据或者制作自己的MNIST数据集

① 读取MNIST文件夹下processed文件中的training.pt、test.pt数据集

class Data_Loader(Dataset):

  def __init__(self, root, transform=None):

      self.data, self.targets = torch.load(root)#采用torch.load进行读取,读取之后的结果为torch.Tensor形式

self.transform = transform

  def __getitem__(self, index):

      img, target = self.data[index], int(self.targets[index])

      img = Image.fromarray(img.numpy(), mode='L')

      if self.transform is not None:

          img = self.transform(img)

      img = transforms.ToTensor()(img)

      return img, target

  def __len__(self):

      return len(self.data)

接下来,调用我们自定义的Data_Loader类来读取数据集:

# root 为training.pt、test.pt文件所在的绝对路径

train_data = Data_Loader(root='./mnist/MNIST/processed/training.pt', transform= None)

test_data = Data_Loader(root='./mnist/MNIST/processed/test.pt', transform= None)

再使用torch.utils.data.DataLoader对train_data和test_data进行加载,展示。

② 读取MNIST文件夹下raw文件中的数据集

class Data_Loader(Dataset):

  def __init__(self, folder, data_name, label_name, transform=None):

      (train_set, train_labels) = load_data(folder, data_name, label_name)

      self.train_set = train_set

      self.train_labels = train_labels

      self.transform = transform

  def __getitem__(self, index):

      img, target = self.train_set[index], int(self.train_labels[index])

      if self.transform is not None:

          img = self.transform(img)

      return img, target

  def __len__(self):

      return len(self.train_set)

def load_data(data_folder, data_name, label_name):

  with gzip.open(os.path.join(data_folder, label_name), 'rb') as lbpath:  # rb表示的是读取二进制数据

      y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

  with gzip.open(os.path.join(data_folder, data_name), 'rb') as imgpath:

      x_train = np.frombuffer(

          imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)

  return x_train, y_train

接下来,调用我们自定义的Data_Loader类来读取数据集:

#folder:MNIST数据集中raw文件的绝对路径

# 读取MNIST数据集中的训练集

train_data = Data_Loader('./MNIST/MNIST/raw', "train-images-idx3-ubyte.gz",

                         "train-labels-idx1-ubyte.gz", transform=transforms.ToTensor())

# 读取MNIST数据集中的测试集

test_data = Data_Loader('./MNIST/MNIST/raw', "t10k-images-idx3-ubyte.gz",

                         "t10k-labels-idx1-ubyte.gz", transform=transforms.ToTensor())

再使用torch.utils.data.DataLoader对train_data和test_data进行加载,展示。

③ 直接读取MNIST数据集


总结

mnist数据集是一个计算机视觉数据集,训练集包括六万张图片,测试集一万张图片,并且已经进行过预处理和格式化。这些数据集有两个功能:一个功能是提供了大量的数据作为训练集和验证集,为一些学习人员提供了丰富的样 本信息一一这一点很宝贵,要知道在深度学习领域要想在一个方面有比较深的研究成果, 除了需要具备一定的网络设计和调优能力以外,还有一个就是要有丰富的训练样本。另一 个功能就是可以形成一个在业内相对有普适性的 Benchmark 比对项目一一既然大家用的数 据集都是一样的,那么每个人设计出来的网络就可以在这些数据集上不断互相比较,从而 验证谁家的网络设计得识别率更高。

目录
相关文章
|
3月前
|
机器学习/深度学习 编解码 PyTorch
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【PyTorch实战演练】AlexNet网络模型构建并使用Cifar10数据集进行批量训练(附代码)
【PyTorch实战演练】AlexNet网络模型构建并使用Cifar10数据集进行批量训练(附代码)
70 0
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【PyTorch实战演练】使用Cifar10数据集训练LeNet5网络并实现图像分类(附代码)
【PyTorch实战演练】使用Cifar10数据集训练LeNet5网络并实现图像分类(附代码)
75 0
|
23天前
|
机器学习/深度学习 数据可视化 PyTorch
利用PyTorch实现基于MNIST数据集的手写数字识别
利用PyTorch实现基于MNIST数据集的手写数字识别
23 2
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
使用PyTorch加载数据集:简单指南
使用PyTorch加载数据集:简单指南
使用PyTorch加载数据集:简单指南
|
3月前
|
机器学习/深度学习 算法 PyTorch
pytorch实现手写数字识别 | MNIST数据集(全连接神经网络)
pytorch实现手写数字识别 | MNIST数据集(全连接神经网络)
|
5月前
|
机器学习/深度学习 自然语言处理 PyTorch
PyTorch搭建RNN联合嵌入模型(LSTM GRU)实现视觉问答(VQA)实战(超详细 附数据集和源码)
PyTorch搭建RNN联合嵌入模型(LSTM GRU)实现视觉问答(VQA)实战(超详细 附数据集和源码)
74 1
|
5月前
|
机器学习/深度学习 数据可视化 PyTorch
PyTorch实现DCGAN(生成对抗网络)生成新的假名人照片实战(附源码和数据集)
PyTorch实现DCGAN(生成对抗网络)生成新的假名人照片实战(附源码和数据集)
58 1
|
5月前
|
机器学习/深度学习 搜索推荐 数据可视化
PyTorch搭建基于图神经网络(GCN)的天气推荐系统(附源码和数据集)
PyTorch搭建基于图神经网络(GCN)的天气推荐系统(附源码和数据集)
94 0
|
2月前
|
机器学习/深度学习 算法 PyTorch
【PyTorch实战演练】深入剖析MTCNN(多任务级联卷积神经网络)并使用30行代码实现人脸识别
【PyTorch实战演练】深入剖析MTCNN(多任务级联卷积神经网络)并使用30行代码实现人脸识别
67 2

相关实验场景

更多