Experimental environment
torch 1.8.0 + torchvision 0.9.0
import torch
import torchvision
print(torch.__version__)
print(torchvision.__version__)
Output:
1.8.0
0.9.0+cpu
1. PyTorch data loading
import torchvision.transforms as tfm
from PIL import Image
img = Image.open('volleyball.png')
img_1 = tfm.RandomCrop(200, padding=50)(img)  # randomly crop the image
img_1.show()
img_1.save('crop.png')
img_2 = tfm.RandomHorizontalFlip()(img)  # randomly flip the image horizontally
img_2.show()
img_2.save('flip.png')
1.1 Data preprocessing
torchvision.transforms
transform_train = tfm.Compose([
    tfm.RandomCrop(32, padding=4),
    tfm.RandomHorizontalFlip(),
    tfm.ToTensor(),  # convert the image to a Tensor
    tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # normalize each channel with mean 0.5 and std 0.5
])
1.2 Data loading
torch.utils.data
loader = torch.utils.data.DataLoader(
    datasets,
    batch_size=32,
    shuffle=True,
    sampler=None,
    num_workers=2,
    collate_fn=None,
    pin_memory=True,
    drop_last=False
)
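datasets: the dataset to load from (passed as DataLoader's dataset argument). It can be a custom Dataset object or one of the predefined datasets in torchvision.
batch_size: the number of samples in each batch.
shuffle: whether to reshuffle the dataset at every epoch.
sampler: the strategy used to draw samples from the dataset. It is mutually exclusive with shuffle: passing a sampler together with shuffle=True raises a ValueError.
num_workers: the number of subprocesses used for data loading; 0 means the data is loaded in the main process.
collate_fn: the function that merges a list of samples into a batch. A custom function can be used for operations such as sorting or padding the samples.
pin_memory: whether to copy the loaded tensors into pinned (page-locked) host memory before returning them. This speeds up the later transfer to the GPU; it does not by itself place the data in GPU memory.
drop_last: whether to drop the final incomplete batch when the number of samples is not divisible by batch_size.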
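The effect of drop_last and collate_fn can be checked on a toy dataset. The following is a minimal sketch; the tensors and the pad_collate helper are hypothetical and exist only for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.nn.utils.rnn import pad_sequence

# Toy dataset: 10 samples with 3 features each, plus integer labels (assumption).
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

# num_workers=0 loads in the main process, which keeps the sketch portable.
keep_loader = DataLoader(dataset, batch_size=4, shuffle=False, drop_last=False, num_workers=0)
drop_loader = DataLoader(dataset, batch_size=4, shuffle=False, drop_last=True, num_workers=0)

keep_sizes = [x.shape[0] for x, _ in keep_loader]
drop_sizes = [x.shape[0] for x, _ in drop_loader]
print(keep_sizes)  # [4, 4, 2]
print(drop_sizes)  # [4, 4]


def pad_collate(batch):
    # batch is a list of 1-D tensors of different lengths;
    # pad them with zeros to the longest one and stack them.
    return pad_sequence(batch, batch_first=True)


# A plain list works as a map-style dataset (it has __getitem__ and __len__).
sequences = [torch.ones(n) for n in (1, 3, 2, 4)]
padded_loader = DataLoader(sequences, batch_size=2, collate_fn=pad_collate)
padded_shapes = [tuple(b.shape) for b in padded_loader]
print(padded_shapes)  # [(2, 3), (2, 4)]
```

The default collate function would fail on the variable-length sequences, which is the typical reason to supply a custom collate_fn.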
2. PyTorch model construction
2.1 Classic models
torchvision.models
from torchvision import models
net1 = models.resnet50()                 # randomly initialized weights
net2 = models.resnet50(pretrained=True)  # ImageNet-pretrained weights, downloaded on first use
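2.2 Loading and saving models
model.load_state_dict(torch.load('pretrained_weights.pth'))  # load saved parameters into an existing model
torch.save(model.state_dict(), 'model_weights.pth')          # save only the parameters (the state_dict)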
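A save and load round trip can be sketched as follows. The small nn.Linear model and the temporary directory are assumptions for illustration; the file name mirrors the one used above.

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Linear(4, 2)

# Save only the parameters, not the whole module object.
path = os.path.join(tempfile.mkdtemp(), 'model_weights.pth')
torch.save(model.state_dict(), path)

# To restore, build a model with the same architecture, then load the parameters.
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(path))
restored.eval()

same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
```

Because a state_dict stores only tensors keyed by parameter name, the model class itself must be available when loading.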
3. PyTorch optimizers
3.1 torch.optim
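from torch import optim
optimizer = optim.SGD([  # SGD: stochastic gradient descent
    {'params': model.base.parameters()},  # uses the default lr=1e-2 given below
    {'params': model.classifier.parameters(), 'lr': 1e-3}  # overrides the default lr for this group
], lr=1e-2, momentum=0.9)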
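The per-group options can be inspected through optimizer.param_groups. The following is a minimal sketch; TinyNet is a hypothetical model that merely provides the base and classifier attributes used above.

```python
import torch
from torch import nn, optim


class TinyNet(nn.Module):
    """Hypothetical model with a `base` and a `classifier` sub-module."""

    def __init__(self):
        super().__init__()
        self.base = nn.Linear(8, 4)
        self.classifier = nn.Linear(4, 2)

    def forward(self, x):
        return self.classifier(torch.relu(self.base(x)))


model = TinyNet()
optimizer = optim.SGD([
    {'params': model.base.parameters()},                    # falls back to lr=1e-2
    {'params': model.classifier.parameters(), 'lr': 1e-3},  # group-specific lr
], lr=1e-2, momentum=0.9)

group_lrs = [g['lr'] for g in optimizer.param_groups]
group_momenta = [g['momentum'] for g in optimizer.param_groups]
print(group_lrs)      # [0.01, 0.001]
print(group_momenta)  # [0.9, 0.9]
```

Options not set inside a group, here momentum for both groups and lr for the first, are taken from the keyword arguments of the optimizer.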
# Training procedure
model = init_model_function()  # build the model
optimizer = optim.SomeOptimizer(  # set up the optimizer
    model.parameters(), lr, mm
)
for data, label in train_dataloader:
    optimizer.zero_grad()  # clear the old gradients before the forward pass
    output = model(data)  # forward pass
    loss = loss_function(output, label)  # compute the loss
    loss.backward()  # backpropagation
    optimizer.step()  # update the parameters
3.2 Learning rate adjustment
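The template above can be made concrete on synthetic data. The following is a minimal sketch, assuming a one-dimensional linear regression task (y = 2x + 1), SGD with momentum as the optimizer, and MSE as the loss; none of these choices come from the original notes.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic regression data: y = 2x + 1 (assumption for illustration).
x = torch.linspace(-1, 1, 64).unsqueeze(1)
y = 2 * x + 1
train_dataloader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

model = nn.Linear(1, 1)                                          # build the model
loss_function = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)  # set up the optimizer

epoch_losses = []
for epoch in range(30):
    running = 0.0
    for data, label in train_dataloader:
        optimizer.zero_grad()                # clear the old gradients
        output = model(data)                 # forward pass
        loss = loss_function(output, label)  # compute the loss
        loss.backward()                      # backpropagation
        optimizer.step()                     # update the parameters
        running += loss.item() * data.size(0)
    epoch_losses.append(running / len(x))    # mean loss over the epoch

print(epoch_losses[0], epoch_losses[-1])     # the loss should shrink over training
```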
scheduler = optim.lr_scheduler.SomeScheduler(optimizer, *args)
for epoch in range(epochs):
    train()
    test()
    scheduler.step()  # update the learning rate once per epoch
Common functions
Activation unit types
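As one concrete instance of this pattern, the following sketch uses StepLR, which multiplies the learning rate by gamma every step_size epochs. The choice of StepLR, its settings, and the single dummy update standing in for train() are assumptions for illustration.

```python
import torch
from torch import nn, optim

model = nn.Linear(2, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1)
# Multiply the learning rate by 0.1 every 2 epochs.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

lrs = []
for epoch in range(6):
    lrs.append(scheduler.get_last_lr()[0])  # learning rate used in this epoch

    # Stand-in for train(): a single parameter update.
    optimizer.zero_grad()
    loss = model(torch.ones(1, 2)).sum()
    loss.backward()
    optimizer.step()

    scheduler.step()  # called after the optimizer updates of the epoch

print(lrs)  # approximately [0.1, 0.1, 0.01, 0.01, 0.001, 0.001]
```

Calling scheduler.step() after optimizer.step() follows the ordering PyTorch expects since version 1.1.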
ELU | MultiheadAttention | SELU | Softshrink | Softmin |
Hardshrink | PReLU | CELU | Softsign | Softmax |
Hardtanh | ReLU | GELU | Tanh | Softmax2d |
LeakyReLU | ReLU6 | Sigmoid | Tanhshrink | LogSoftmax |
LogSigmoid | RReLU | Softplus | Threshold |
Loss function layer types
L1Loss | PoissonNLLLoss | HingeEmbeddingLoss | CosineEmbeddingLoss |
MSELoss | KLDivLoss | MultiLabelMarginLoss | MultiMarginLoss |
CrossEntropyLoss | BCELoss | SmoothL1Loss | TripletMarginLoss |
CTCLoss | BCEWithLogitsLoss | SoftMarginLoss | |
NLLLoss | MarginRankingLoss | MultiLabelSoftMarginLoss |
Optimizer types
Adadelta | AdamW | ASGD | Rprop |
Adagrad | SparseAdam | LBFGS | SGD |
Adam | Adamax | RMSprop |
Transform operation types
Compose | RandomAffine | RandomOrder | Resize | ToTensor |
CenterCrop | RandomApply | RandomPerspective | Scale | Lambda |
ColorJitter | RandomChoice | RandomResizedCrop | TenCrop | |
FiveCrop | RandomCrop | RandomRotation | LinearTransformation |
Grayscale | RandomGrayscale | RandomSizedCrop | Normalize | |
Pad | RandomHorizontalFlip | RandomVerticalFlip | ToPILImage |
Dataset names
MNIST | CocoCaptions | CIFAR10 | Flickr8k | USPS |
FashionMNIST | CocoDetection | CIFAR100 | Flickr30k | Kinetics400 |
KMNIST | LSUN | STL10 | VOCSegmentation | HMDB51 |
EMNIST | ImageFolder | SVHN | VOCDetection | UCF101 |
QMNIST | DatasetFolder | PhotoTour | Cityscapes | CelebA |
FakeData | ImageNet | SBU | SBDataset |
All classification models implemented in torchvision.models
AlexNet | VGG-13-bn | ResNet-101 | Densenet-201 | ResNeXt-50-32x4d |
VGG-11 | VGG-16-bn | ResNet-152 | Densenet-161 | ResNeXt-101-32x8d |
VGG-13 | VGG-19-bn | SqueezeNet | Inception-V3 | Wide ResNet-50-2 |
VGG-16 | ResNet-18 | GoogLeNet | Wide ResNet-101-2 |
VGG-19 | ResNet-34 | Densenet-121 | ShuffleNet-V2 | MNASNet 1.0 |
VGG-11-bn | ResNet-50 | Densenet-169 | MobileNet-V2 |