1.迁移学习的概念
迁移学习的概念是:我们不必从头重新训练一个网络,而是可以基于已有的网络,借用这个网络的权重,只需稍微修改少数几层的权重,就能达到比较好的效果。
常见的迁移学习方式:
- 载入权重后训练所有参数
- 载入权重后只训练最后几层参数
- 载入权重后在原网络基础上再添加一层全连接层,仅训练最后一个全连接层
在PyTorch中,torchvision提供了很多常见网络结构的预训练模型,这些就是迁移学习的基础。
对于【14】自定义宝可梦数据集一节中实现的自定义数据集,如果我们选择自己写的ResNet18/50网络结构去训练(详情见【15】ResNet结构的pytorch实现),以10个epoch为例,最高的准确度acc只有80%左右。但是,如果使用迁移学习的方法,以30个epoch为例,最高的准确度可以达到0.974,测试集准确率也有0.94。
2.迁移学习的实现
1)自定义模型结构参考代码
train.py
"""Training script: train a ResNet on the custom Pokemon dataset and
save the checkpoint with the best validation accuracy to ``best.mdl``."""
import torch
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoader

from Pokemon import Pokemon
from model import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152

epoch_size = 5
learning_rate = 1e-3
batch_size = 32
resize = 224
# Raw string: the Windows path contains backslashes that must not be
# interpreted as escape sequences.
root = r'E:\学习\机器学习\数据集\pokemon'

train_data = Pokemon(root=root, resize=resize, mode='train')
val_data = Pokemon(root=root, resize=resize, mode='val')
test_data = Pokemon(root=root, resize=resize, mode='test')

train_loader = DataLoader(train_data, batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size, shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = ResNet50().to(device)
model = ResNet18().to(device)  # move the model to GPU when one is available
print(model)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

best_acc = 0
best_epoch = 0
for epoch in range(epoch_size):
    # --- training pass ---
    # BUG FIX: the original iterated over test_loader here, i.e. it was
    # training on the test split; training must use train_loader.
    model.train()
    for batchidx, (image, label) in enumerate(train_loader):
        image = image.to(device)
        label = label.to(device)

        logits = model(image)
        loss = criterion(logits, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batchidx % 2 == 0:
            # loss.item() prints the scalar value instead of a tensor repr
            print("epoch:{}/{}, batch:{}/{}, loss:{}"
                  .format(epoch + 1, epoch_size, batchidx, len(train_loader), loss.item()))

    # --- validation pass: keep the checkpoint with the best accuracy ---
    model.eval()
    correct = 0
    for image, label in val_loader:
        image = image.to(device)
        label = label.to(device)
        with torch.no_grad():
            logits = model(image)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, label).sum().float().item()
    acc = correct / len(val_data)
    print("epoch:{}, acc:{}".format(epoch + 1, acc))

    if acc > best_acc:
        best_acc = acc
        best_epoch = epoch
        torch.save(model.state_dict(), 'best.mdl')
        print("[get best epoch]- best_acc:{}, best_epoch:{}".format(best_acc, best_epoch))
test.py
"""Evaluation script: load the best checkpoint and report test-set accuracy."""
import torch
from torch.utils.data import DataLoader

from Pokemon import Pokemon
from model import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152

epoch_size = 5
learning_rate = 1e-3
batch_size = 32
resize = 224
# Raw string: the Windows path contains backslashes that must not be
# interpreted as escape sequences.
root = r'E:\学习\机器学习\数据集\pokemon'

test_data = Pokemon(root=root, resize=resize, mode='test')
test_loader = DataLoader(test_data, batch_size, shuffle=True)

model = ResNet18()
model.load_state_dict(torch.load('best.mdl'))
# BUG FIX: switch to inference mode so BatchNorm/Dropout layers use their
# evaluation behaviour; the original never called model.eval(), which can
# distort the reported accuracy.
model.eval()

# --- test-set evaluation ---
correct = 0
for image, label in test_loader:
    with torch.no_grad():
        logits = model(image)
        pred = logits.argmax(dim=1)
    correct += torch.eq(pred, label).sum().float().item()

print("len(test_loader):", len(test_data))
acc = correct / len(test_data)
print("final acc:", acc)
2)迁移学习模型结构参考代码
(大多数代码与上面相同,主要改变的是模型定义部分)
from torchvision.models import resnet18
from utils import Flatten

# Transfer learning: reuse an ImageNet-pretrained ResNet18 as the feature
# extractor and replace only the classification head with a 5-way linear
# layer for the Pokemon classes.
# model = ResNet18()
trained_model = resnet18(pretrained=True)
# Keep every child module except the final fully-connected layer.
feature_extractor = list(trained_model.children())[:-1]
model = nn.Sequential(
    *feature_extractor,   # torch.Size([32, 512, 1, 1])
    Flatten(),            # torch.Size([32, 512])
    nn.Linear(512, 5),    # torch.Size([32, 5])
)
model.load_state_dict(torch.load('best.mdl'))
utils.py
from matplotlib import pyplot as plt
import torch
from torch import nn


class Flatten(nn.Module):
    """Flatten every dimension after the batch dimension, turning an
    (N, C, H, W) tensor into (N, C*H*W)."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # Features per sample = product of all non-batch dimensions.
        num_features = x.shape[1:].numel()
        return x.view(-1, num_features)


def plot_image(img, label, name):
    """Display the first six images of a batch with their labels.

    Pixel values are de-normalised with mean 0.1307 / std 0.3081
    (presumably the MNIST statistics — confirm against the dataset's
    normalisation) before display.
    """
    fig = plt.figure()
    for idx in range(6):
        plt.subplot(2, 3, idx + 1)
        plt.tight_layout()
        plt.imshow(img[idx][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[idx].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()
对于之前的自定义结构,只需要稍微改变几行代码,准确率就有了大大的提升,测试集也达到了0.94的效果。
所以,为了提高模型的准确率,可以尝试使用迁移学习的方式。