【PyTorch】Torchvision

简介: 【PyTorch】Torchvision

三、Torchvision

PyTorch官网:https://pytorch.org

1、Dataset

数据集描述:https://www.cs.toronto.edu/~kriz/cifar.html

数据集使用说明:

CIFAR10数据集:https://pytorch.org/vision/stable/generated/torchvision.datasets.CIFAR10.html#torchvision.datasets.CIFAR10

参数说明:

  • root:数据集存放位置
  • train:True(训练集)、False(测试集)
  • transform:变化
  • target_transform:target变化
  • download:是否下载

基本使用:

import torchvision
train_set = torchvision.datasets.CIFAR10(root="../data", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="../data", train=False, download=True)
print(test_set[0])
print(test_set.classes)
img, target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])
img.show()
Files already downloaded and verified
Files already downloaded and verified
(<PIL.Image.Image image mode=RGB size=32x32 at 0x23CD61F0220>, 3)
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
<PIL.Image.Image image mode=RGB size=32x32 at 0x23CD61F00D0>
3
cat

转为Tensor类型: 并使用TensorBoard显示

import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="../data", transform=dataset_transform, train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="../data", transform=dataset_transform, train=False, download=True)
writer = SummaryWriter("logs")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)
writer.close()

2、DataLoader

介绍:https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader

参数说明:

  • batch_size:每批要加载多少个样品(默认:1)
  • shuffle:True(重新洗牌),(默认:False)
  • num_workers:使用多少个子进程来加载数据,(默认:0 表示主进程)
  • drop_last:是否舍去最后(除不尽的)

2.1 test_data

import torchvision
from torch.utils.data import DataLoader
# 准备测试集
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())
# 测试集第一张图片及target
img, target = test_data[0]
print(img.shape)
print(target)
torch.Size([3, 32, 32]) # 3通道 32 * 32
3

2.2 test_loader

import torchvision
from torch.utils.data import DataLoader
# 准备测试集
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
# 测试集第一张图片及target
# img, target = test_data[0]
# print(img.shape)
# print(target)
# test_loader
for data in test_loader:
    imgs, targets = data
    print(imgs.shape)
    print(targets)
torch.Size([4, 3, 32, 32]) # 4张 3通道 32 * 32
tensor([1, 2, 0, 8]) # 4张图片的target糅合在一起
...
...

注意:target[1, 2, 0, 8]并不是按序采样,而是随机的!

2.3 drop_last

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# 准备测试集
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
# batch_size=64
writer = SummaryWriter("logs")
step = 0
for data in test_loader:
    imgs, targets = data
    writer.add_images("test_data", imgs, step)
    step += 1
writer.close()

注意:最后一次采样只有16张图像,这是因为参数drop_last=False

当不满足每一次都取一定值的图片时,可以显示真实剩下的或者直接舍去(drop_last=True)。

当我们设置为drop_last=True时,就会舍去最后一组采样:

2.4 shuffle

import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# 准备测试集
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=True)
# shuffle=False
writer = SummaryWriter("logs")
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        writer.add_images("Epoch:{}".format(epoch), imgs, step)
        step += 1
writer.close()

注意:两者采样完全相同,如果想要 “洗牌”,应设置shuffle=True

目录
相关文章
|
PyTorch 算法框架/工具 计算机视觉
【PyTorch】Torchvision Models
【PyTorch】Torchvision Models
288 0
|
数据可视化 PyTorch 算法框架/工具
Pytorch可视化Visdom、tensorboardX和Torchvision
Pytorch可视化Visdom、tensorboardX和Torchvision
104 0
|
6月前
|
PyTorch 算法框架/工具
win10下安装pytorch,torchvision遇到的bug
win10下安装pytorch,torchvision遇到的bug
|
自然语言处理 并行计算 PyTorch
基于Pytorch中安装torchvision简单详细完整版
基于Pytorch中安装torchvision简单详细完整版
2784 1
基于Pytorch中安装torchvision简单详细完整版
|
机器学习/深度学习 固态存储 PyTorch
pytorch中torchvision读取预训练模型
pytorch中torchvision读取预训练模型
208 0
pytorch中torchvision读取预训练模型
|
数据采集 机器学习/深度学习 PyTorch
Pytorch中基于MNIST数据的torchvision工具包应用
Pytorch中基于MNIST数据的torchvision工具包应用
145 0
Pytorch中基于MNIST数据的torchvision工具包应用
|
PyTorch 算法框架/工具 计算机视觉
Pytorch中torchvision包transforms模块应用小案例
Pytorch中torchvision包transforms模块应用小案例
169 0
Pytorch中torchvision包transforms模块应用小案例
|
PyTorch 算法框架/工具 Caffe
解决办法:KeyError: ‘ExpandBackward’及老版本pytorch/torchvision的安装办法。
解决办法:KeyError: ‘ExpandBackward’及老版本pytorch/torchvision的安装办法。
112 0
|
机器学习/深度学习 人工智能 并行计算
|
2月前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
397 2