3. Implementing Transfer Learning with PyTorch
File layout: the project consists of three files: dataset.py for data preprocessing, model.py for model construction, and run.py for training and validation.
3.1 Dataset Preprocessing
dataset.py
from torchvision import datasets, transforms
import torch

# Training transforms: random-crop and flip augmentation,
# then normalize with the ImageNet channel means and stds
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),  # randomly crop an area, then resize it to 224x224
    transforms.RandomHorizontalFlip(),  # random horizontal flip
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Validation transforms: deterministic resize and center crop, same normalization
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

trainset = datasets.ImageFolder(root='hymenoptera_data/train', transform=train_transform)
valset = datasets.ImageFolder(root='hymenoptera_data/val', transform=val_transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=4)
# No need to shuffle the validation set
valloader = torch.utils.data.DataLoader(valset, batch_size=4, shuffle=False, num_workers=4)
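Before wiring the loaders into training, it helps to confirm that ImageFolder discovered the classes and that batches have the expected shape. A minimal sanity-check sketch, assuming the hymenoptera_data folders above are in place:

from dataset import trainset, trainloader

print(trainset.class_to_idx)   # class-name-to-index mapping, e.g. {'ants': 0, 'bees': 1}
images, labels = next(iter(trainloader))
print(images.shape)            # torch.Size([4, 3, 224, 224]): a batch of 4 normalized crops
print(labels)                  # tensor of 4 class indices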
3.2 Building the Model
model.py
from torchvision import models
import torch.nn as nn

# When feature extracting, freeze all pretrained parameters so that
# only the newly reshaped layer(s) receive gradient updates
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

def initialize_model(model_name, num_classes, feature_extract):
    model_ft = None
    input_size = 0
    if model_name == 'resnet':
        # ResNet-18
        model_ft = models.resnet18(pretrained=True)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "alexnet":
        model_ft = models.alexnet(pretrained=True)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "vgg":
        # VGG-11 with batch norm
        model_ft = models.vgg11_bn(pretrained=True)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "squeezenet":
        model_ft = models.squeezenet1_0(pretrained=True)
        set_parameter_requires_grad(model_ft, feature_extract)
        # SqueezeNet classifies with a 1x1 conv rather than a linear layer
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224
    elif model_name == "densenet":
        model_ft = models.densenet121(pretrained=True)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == "inception":
        # Inception v3 expects 299x299 inputs and has an auxiliary output
        model_ft = models.inception_v3(pretrained=True)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxiliary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299
    else:
        print("No matching model...")
    return model_ft, input_size
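A quick way to confirm that the freezing logic works: with feature_extract=True, only the parameters of the replaced layer should still require gradients. A minimal sketch (the model name and class count here are chosen just for illustration):

from model import initialize_model

model_ft, input_size = initialize_model('resnet', num_classes=2, feature_extract=True)
print(input_size)  # 224
for name, param in model_ft.named_parameters():
    if param.requires_grad:
        print(name)  # expect only fc.weight and fc.bias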
3.3 Model Training and Validation
run.py
from __future__ import print_function
from __future__ import division
import time
import copy
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler

from model import initialize_model
from dataset import trainset, valset, trainloader, valloader

parser = argparse.ArgumentParser()
# Model selection
parser.add_argument('-m', '--model_name', type=str,
                    choices=['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception'],
                    help="input model_name", default='resnet')
# Number of target classes
parser.add_argument('-n', '--num_classes', type=int, help="input num_classes", default=2)
# Batch size (note: the DataLoaders are built in dataset.py with their own batch_size)
parser.add_argument('-b', '--batch_size', type=int, help="input batch_size", default=8)
# Number of training epochs
parser.add_argument('-e', '--num_epochs', type=int, help="input num_epochs", default=25)
args = parser.parse_args()

# Feature-extraction flag. If False, the whole model is fine-tuned;
# if True, only the parameters of the reshaped layer are updated
feature_extract = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset and DataLoader dictionaries, keyed by phase name
image_datasets = {'train': trainset, 'val': valset}
dataloaders = {'train': trainloader, 'val': valloader}

# initialize_model also returns the expected input size (299 for inception,
# 224 otherwise); dataset.py crops to 224, so inception would need adjusted
# transforms and special handling of its auxiliary output during training
model_ft, input_size = initialize_model(args.model_name, args.num_classes, feature_extract)
model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Only optimize the parameters that still require gradients
# (just the reshaped layer when feature_extract is True)
params_to_update = [p for p in model_ft.parameters() if p.requires_grad]
optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)

# Decay the LR by gamma=0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

def train_model(model, criterion, optimizer, scheduler, num_epochs):
    since = time.time()
    val_acc_history = []
    # Keep a copy of the best weights seen so far
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            running_loss = 0.0
            running_corrects = 0
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()
                # Track gradients only during the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                # Step the scheduler once per epoch, after the optimizer updates
                scheduler.step()
            epoch_loss = running_loss / len(image_datasets[phase])
            epoch_acc = running_corrects.double() / len(image_datasets[phase])
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            # Deep-copy the weights whenever validation accuracy improves
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'val':
                val_acc_history.append(epoch_acc)
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))
    # Load the best weights before returning the model
    model.load_state_dict(best_model_wts)
    return model

train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, args.num_epochs)
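The script is launched from the command line, for example: python run.py -m resnet -n 2 -e 25. It fine-tunes the chosen backbone and reports per-epoch loss and accuracy for both phases. To keep the fine-tuned weights, one small variation of the final call (the output filename here is just an example) captures the returned model and saves its state dict:

best_model = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, args.num_epochs)
torch.save(best_model.state_dict(), 'best_model.pth')  # example path, adjust as needed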