III. Torchvision
PyTorch official site: https://pytorch.org
1. Dataset
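Dataset description: https://www.cs.toronto.edu/~kriz/cifar.html
Using the dataset (torchvision.datasets.CIFAR10):
Parameters:
- root: directory where the dataset is stored
- train: True (training set) or False (test set)
- transform: transform applied to the image
- target_transform: transform applied to the target (label)
- download: whether to download the dataset; if it is already present under root, it is not downloaded again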
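Datasets like this typically apply `transform` to the sample and `target_transform` to the label inside `__getitem__`. The sketch below shows that pattern using a hypothetical `ToyImageDataset`. It is not torchvision's actual CIFAR10 code. Plain lists stand in for images, so it runs without downloading anything.

```python
from torch.utils.data import Dataset


class ToyImageDataset(Dataset):
    """Hypothetical dataset mimicking the (img, target) interface of CIFAR10."""

    def __init__(self, samples, targets, transform=None, target_transform=None):
        assert len(samples) == len(targets)
        self.samples = samples
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        img, target = self.samples[index], self.targets[index]
        # transform acts on the sample, target_transform on the label
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target


classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']
ds = ToyImageDataset(
    samples=[[0, 1, 2], [3, 4, 5]],                # stand-ins for images
    targets=[3, 8],
    transform=lambda x: [v * 2 for v in x],        # toy "image" transform
    target_transform=lambda t: classes[t],         # label index -> class name
)
print(len(ds))  # 2
print(ds[0])    # ([0, 2, 4], 'cat')
```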
Basic usage:
import torchvision
train_set = torchvision.datasets.CIFAR10(root="../data", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="../data", train=False, download=True)
print(test_set[0])
print(test_set.classes)
img, target = test_set[0]
print(img)
print(target)
print(test_set.classes[target])
img.show()
Output:
Files already downloaded and verified
Files already downloaded and verified
(<PIL.Image.Image image mode=RGB size=32x32 at 0x23CD61F0220>, 3)
['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
<PIL.Image.Image image mode=RGB size=32x32 at 0x23CD61F00D0>
3
cat
Converting to Tensor and displaying the images in TensorBoard:
import torchvision
from torch.utils.tensorboard import SummaryWriter
# Convert each PIL image to a tensor when it is loaded
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root="../data", transform=dataset_transform, train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="../data", transform=dataset_transform, train=False, download=True)
# Write the first 10 test images to TensorBoard (view with: tensorboard --logdir=logs)
writer = SummaryWriter("logs")
for i in range(10):
    img, target = test_set[i]
    writer.add_image("test_set", img, i)
writer.close()
2. DataLoader
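Documentation: https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader
Parameters:
- batch_size: how many samples to load per batch (default: 1)
- shuffle: set to True to reshuffle the data at every epoch (default: False)
- num_workers: how many subprocesses to use for data loading (default: 0, meaning the data is loaded in the main process)
- drop_last: whether to drop the last incomplete batch when the dataset size is not divisible by batch_size (default: False)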
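This sketch uses the parameters above on a small synthetic dataset, so no download is needed. The 10 all-zero "images" wrapped in `TensorDataset` are an assumption made for illustration. With `batch_size=4` and `drop_last=False`, 10 samples yield 3 batches, and the last one holds only 2 samples.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for an image dataset: 10 all-zero "images" of shape 3x32x32
images = torch.zeros(10, 3, 32, 32)
labels = torch.arange(10)
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0, drop_last=False)

print(len(loader))  # 3 batches: 4 + 4 + 2 samples
for imgs, targets in loader:
    print(imgs.shape, targets)
# torch.Size([4, 3, 32, 32]) tensor([0, 1, 2, 3])
# torch.Size([4, 3, 32, 32]) tensor([4, 5, 6, 7])
# torch.Size([2, 3, 32, 32]) tensor([8, 9])
```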
2.1 test_data
import torchvision
from torch.utils.data import DataLoader
# Prepare the test set (already downloaded to ../data above)
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())
# First image of the test set and its target
img, target = test_data[0]
print(img.shape)
print(target)
torch.Size([3, 32, 32])   # 3 channels, 32 x 32
3
2.2 test_loader
import torchvision
from torch.utils.data import DataLoader
# Prepare the test set
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, num_workers=0, drop_last=False)
# First image of the test set and its target
# img, target = test_data[0]
# print(img.shape)
# print(target)
# Iterate over test_loader
for data in test_loader:
    imgs, targets = data
    print(imgs.shape)
    print(targets)
torch.Size([4, 3, 32, 32])   # 4 images, 3 channels, 32 x 32
tensor([1, 2, 0, 8])         # the targets of the 4 images, combined into one tensor
...
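Note: the targets [1, 2, 0, 8] are not in dataset order (the first test image, for instance, has target 3). Because shuffle=True, the samples are drawn in random order.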
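The sketch below uses a hypothetical list of (image tensor, int label) pairs shaped like CIFAR-10 samples after `ToTensor()`. It shows two points. First, default collation stacks the images and merges the Python int labels into one int64 tensor, as in the output above. Second, a shuffled epoch still visits every sample exactly once. The seeded `torch.Generator` passed through `generator` only makes the sketch reproducible and is optional.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical map-style dataset: a plain list of (image tensor, int label) pairs;
# every pixel of image i equals i, and its label is also i
data = [(torch.full((3, 32, 32), float(i)), i) for i in range(12)]

loader = DataLoader(data, batch_size=4, shuffle=True,
                    generator=torch.Generator().manual_seed(0))

seen = []
for imgs, targets in loader:
    # Default collation stacks the 4 images into one [4, 3, 32, 32] tensor and
    # merges the 4 Python int labels into a single int64 tensor
    print(imgs.shape, targets.dtype, targets.tolist())
    seen.extend(targets.tolist())

# The order is random, but every sample is still visited exactly once per epoch
print(sorted(seen) == list(range(12)))  # True
```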
2.3 drop_last
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# Prepare the test set
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)  # batch_size=64
writer = SummaryWriter("logs")
step = 0
for data in test_loader:
    imgs, targets = data
    # add_images (plural) logs a whole batch of shape [N, C, H, W]
    writer.add_images("test_data", imgs, step)
    step += 1
writer.close()
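Note: the last batch contains only 16 images because drop_last=False. The 10,000 test images split into 156 full batches of 64 with 16 left over. When the dataset size is not a multiple of batch_size, the final incomplete batch is either kept as-is (drop_last=False) or discarded (drop_last=True).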
Setting drop_last=True discards that final incomplete batch, so only the 156 full batches of 64 images are written.
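The batch arithmetic can be checked without downloading CIFAR-10. This sketch uses a hypothetical `TensorDataset` of 10,000 scalars, the same size as the test set. It compares `len()` of two loaders that differ only in `drop_last`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in with the same size as the CIFAR-10 test set (10,000 samples);
# each "image" is a single number to keep the sketch light
dataset = TensorDataset(torch.arange(10_000))

keep = DataLoader(dataset, batch_size=64, drop_last=False)
drop = DataLoader(dataset, batch_size=64, drop_last=True)

print(len(keep), len(drop))  # 157 156
sizes = [batch[0].shape[0] for batch in keep]
print(sizes[-1])             # 16: the incomplete final batch
print(divmod(10_000, 64))    # (156, 16)
```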
2.4 shuffle
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# Prepare the test set
test_data = torchvision.datasets.CIFAR10("../data", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False, num_workers=0, drop_last=True)  # shuffle=False
writer = SummaryWriter("logs")
# Run two epochs and log each one under its own tag
for epoch in range(2):
    step = 0
    for data in test_loader:
        imgs, targets = data
        writer.add_images("Epoch:{}".format(epoch), imgs, step)
        step += 1
writer.close()
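Note: with shuffle=False the two epochs produce exactly the same batches in the same order. To reshuffle the data at the start of every epoch, set shuffle=True.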
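The sketch below compares the two settings on a hypothetical 100-sample `TensorDataset`. It records the sample order of each full pass. With `shuffle=False` every epoch is identical. With `shuffle=True` each epoch is a new permutation of the same samples.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(100))  # hypothetical 100-sample dataset


def epoch_order(loader):
    """Return the sample order produced by one full pass over `loader`."""
    return [int(x) for (batch,) in loader for x in batch]


fixed = DataLoader(dataset, batch_size=10, shuffle=False)
print(epoch_order(fixed) == epoch_order(fixed))  # True: same order every epoch

shuffled = DataLoader(dataset, batch_size=10, shuffle=True,
                      generator=torch.Generator().manual_seed(42))
epoch0, epoch1 = epoch_order(shuffled), epoch_order(shuffled)
print(epoch0 != epoch1)                    # True (almost surely): reshuffled each epoch
print(sorted(epoch0) == list(range(100)))  # True: still a permutation of all samples
```

Passing a seeded `torch.Generator` through `generator` makes the shuffled order reproducible across runs. Without it, the shuffle is seeded from PyTorch's global random state.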