Pytorch中基于MNIST数据的torchvision工具包应用

简介: Pytorch中基于MNIST数据的torchvision工具包应用

一、torchvision简介

torchvision工具包主要包含以下三个部分:

  • models:提供深度学习中各种经典网络的网络结构和预训练好的模型,包括ResNet系列等。
  • datasets:提供常用的数据集加载,设计上继承torch.utils.data.Dataset,主要包括MNIST等数据集。同时datasets下包含这个ImageFolder方法,这个方法的实现和 博主这篇文章代码中的DogCat类(点击打开文章网页) 很相似,可以用来读取用户自己的图像数据集,而非datasets自带的数据集。
  • transforms:提供常用数据预处理操作,主要包括对Tensor和PIL Image 对象的操作。

二、torchvision安装

基于Pytorch中安装torchvision简单详细完整版:点击打开文章网页

三、应用要求和实现流程及注意事项

(1)应用要求:主要是对MNIST数据集图片进行处理,首先自定义操作transforms,然后对每批次图像进行transforms处理,再将该批次的图像拼接成一张网格图像,再保存展示图像。

(2)具体实现流程:按下面代码的括号中的的顺序和注释依次进行理解。

(3)注意事项:注意transforms里面输入图像数据的通道数和设计是否匹配;每批次处理的图像数据的数目大小要明确;对象的迭代结果依旧是对象;tensor数据格式和PIL Image格式的转换。

四、代码及结果

import torch
from torchvision import datasets
import torchvision.transforms as T
from torch.utils.data import DataLoader
import numpy as np
from torchvision.utils import make_grid,save_image
transform = T.Compose([ # 等同于sequential,调用方式也一致,此transforms输入数据类型是PIL Image,输出数据类型是tensor (2)
    T.Resize(224), # 缩放图片,保持长宽比不变,最短边为224像素
    T.CenterCrop(224), # 从图片中间切出224*224的图片
    T.ToTensor(), # 将图片(Image)转换成Tensor,归一化[0,1]
    T.Normalize(mean=[.5],std=[.5]) # ,注意通道数的变化,此时输入数据的通道数为1,数据维度要跟着变化,标准化[-1,1],处理后格式依旧为tensor格式
])
# torchvision.datasets提供常用数据集下载
# root指定数据集下载的路径,若之前没有下载程序会进行自动下载,train=False获取数据集中的测试数据集 (1)
dataset = datasets.MNIST('data/',download=True,train=False,transform=transform)
# print(dataset.data.type()) # 验证dataset对象的数据data属性是tensor
# 加载之前获取的数据集dataset进行批次处理,生成一个可迭代的对象
dataload = DataLoader(dataset,shuffle=True,batch_size=16) # 打乱顺序,每批次数据取16个图像(3)
# iter是用来每一次返回指定对象的迭代器
dataiter = iter(dataload) (4)
# 迭代器迭代一次后用next方法返回一个迭代对象,对象的数据属性data也就是下面的next(dataiter)[0]是tensor数据格式,将每批次图像拼接成4*4网格图片,且通道数量可转成3通道数,make_grad返回tensor格式数据
img = make_grid(next(dataiter)[0],4) # (5)
# 保存图片,输入要求tensor格式数据
save_image(img,'number.png') # (6)
# 展示图片,可以先转成PILImage图像格式
img = T.ToPILImage()(img) # (7)
img.show() # (8)


相关文章
|
2月前
|
机器学习/深度学习 算法 PyTorch
【PyTorch实战演练】自调整学习率实例应用(附代码)
【PyTorch实战演练】自调整学习率实例应用(附代码)
148 0
|
2月前
|
机器学习/深度学习 编解码 PyTorch
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
|
9月前
|
机器学习/深度学习 数据采集 PyTorch
使用自定义 PyTorch 运算符优化深度学习数据输入管道
使用自定义 PyTorch 运算符优化深度学习数据输入管道
46 0
|
8月前
|
数据可视化 PyTorch 算法框架/工具
Pytorch可视化Visdom、tensorboardX和Torchvision
Pytorch可视化Visdom、tensorboardX和Torchvision
54 0
|
2月前
|
数据采集 PyTorch 算法框架/工具
PyTorch基础之数据模块Dataset、DataLoader用法详解(附源码)
PyTorch基础之数据模块Dataset、DataLoader用法详解(附源码)
647 0
|
6天前
|
机器学习/深度学习 人工智能 PyTorch
PyTorch框架和MNIST数据集
6月更文挑战20天
39 2
|
13天前
|
PyTorch 算法框架/工具
win10下安装pytorch,torchvision遇到的bug
win10下安装pytorch,torchvision遇到的bug
16 1
|
24天前
|
机器学习/深度学习 数据可视化 PyTorch
PyTorch深度学习框架入门与应用
PyTorch 提供了丰富的工具和 GPU 加速功能,便于构建和训练神经网络。基础包括:1) 张量,类似 NumPy,支持 GPU 计算;2) 自动微分,方便计算梯度;3) 内置神经网络模块 `nn`。PyTorch 还支持数据并行、自定义层、模型保存加载、模型可视化和剪枝量化等进阶用法。通过不断学习,你将能掌握更多高级功能。【6月更文挑战第6天】
33 8
|
2月前
|
机器学习/深度学习 PyTorch API
|
2月前
|
机器学习/深度学习 PyTorch TensorFlow
TensorFlow与PyTorch在Python面试中的对比与应用
【4月更文挑战第16天】这篇博客探讨了Python面试中TensorFlow和PyTorch的常见问题,包括框架基础操作、自动求梯度与反向传播、数据加载与预处理。易错点包括混淆框架API、动态图与静态图的理解、GPU加速的利用、模型保存恢复以及版本兼容性。通过掌握这些问题和解决策略,面试者能展示其深度学习框架技能。
54 9