CIFAR10
通过这个代码会把数据集自动下载到root路径,然后通过root路径获取到训练和测试的数据集,再结合网络模型进行训练。
import torch
import torchvision
import torchvision.transforms as transforms
"""加载CIFAR10"""
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
"""如果要进行其他的数据增强则只需要改动transform里面的代码即可"""
# data_transform = {
# "train": transforms.Compose([transforms.RandomResizedCrop(224), # 随机裁剪,在缩放成224*224
# transforms.RandomHorizontalFlip(), # 水平方向随机翻转,概率为0.5
# transforms.ToTensor(),
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
# "test": transforms.Compose([transforms.Resize(256),
# transforms.CenterCrop(224),
# transforms.ToTensor(),
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
trainset = torchvision.datasets.CIFAR10(root='E:\dataset\cifar_10', train=True, download=True, transform=transform) # data_transform["train"]
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='E:\dataset\cifar_10', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
print(trainloader)
print(testloader)
print(testloader)
print(classes)
输出如下: