一、torchvision简介
torchvision工具包主要包含以下三个部分:
- models:提供深度学习中各种经典网络的网络结构和预训练好的模型,包括ResNet系列等。
- datasets:提供常用的数据集加载,设计上继承torch.utils.data.Dataset,主要包括MNIST等数据集。同时datasets下包含这个ImageFolder方法,这个方法的实现和 博主这篇文章代码中的DogCat类(点击打开文章网页) 很相似,可以用来读取用户自己的图像数据集,而非datasets自带的数据集。
- transforms:提供常用数据预处理操作,主要包括对Tensor和PIL Image 对象的操作。
二、torchvision安装
基于Pytorch中安装torchvision简单详细完整版:点击打开文章网页
三、应用要求和实现流程及注意事项
(1)应用要求:主要是对MNIST数据集图片进行处理,首先自定义操作transforms,然后对每批次图像进行transforms处理,再将该批次的图像拼接成一张网格图像,再保存展示图像。
(2)具体实现流程:按下面代码的括号中的的顺序和注释依次进行理解。
(3)注意事项:注意transforms里面输入图像数据的通道数和设计是否匹配;每批次处理的图像数据的数目大小要明确;对象的迭代结果依旧是对象;tensor数据格式和PIL Image格式的转换。
四、代码及结果
import torch from torchvision import datasets import torchvision.transforms as T from torch.utils.data import DataLoader import numpy as np from torchvision.utils import make_grid,save_image transform = T.Compose([ # 等同于sequential,调用方式也一致,此transforms输入数据类型是PIL Image,输出数据类型是tensor (2) T.Resize(224), # 缩放图片,保持长宽比不变,最短边为224像素 T.CenterCrop(224), # 从图片中间切出224*224的图片 T.ToTensor(), # 将图片(Image)转换成Tensor,归一化[0,1] T.Normalize(mean=[.5],std=[.5]) # ,注意通道数的变化,此时输入数据的通道数为1,数据维度要跟着变化,标准化[-1,1],处理后格式依旧为tensor格式 ]) # torchvision.datasets提供常用数据集下载 # root指定数据集下载的路径,若之前没有下载程序会进行自动下载,train=False获取数据集中的测试数据集 (1) dataset = datasets.MNIST('data/',download=True,train=False,transform=transform) # print(dataset.data.type()) # 验证dataset对象的数据data属性是tensor # 加载之前获取的数据集dataset进行批次处理,生成一个可迭代的对象 dataload = DataLoader(dataset,shuffle=True,batch_size=16) # 打乱顺序,每批次数据取16个图像(3) # iter是用来每一次返回指定对象的迭代器 dataiter = iter(dataload) (4) # 迭代器迭代一次后用next方法返回一个迭代对象,对象的数据属性data也就是下面的next(dataiter)[0]是tensor数据格式,将每批次图像拼接成4*4网格图片,且通道数量可转成3通道数,make_grad返回tensor格式数据 img = make_grid(next(dataiter)[0],4) # (5) # 保存图片,输入要求tensor格式数据 save_image(img,'number.png') # (6) # 展示图片,可以先转成PILImage图像格式 img = T.ToPILImage()(img) # (7) img.show() # (8)