【15】宝可梦数据集基于迁移学习训练

简介: 【15】宝可梦数据集基于迁移学习训练

1.迁移学习的概念


迁移学习的概念就是其实我们不必去重新的训练一个网络,而是我们可以基于其他的网络,借用这个网络的权重,然后稍微的去修改少层数的权重,从而达到一个比较好的效果。

image.png


常见的迁移学习方式:

  1. 载入权重后训练所有参数
  2. 载入权重后只训练最后几层参数
  3. 载入权重后在原网络基础上再添加一层全连接层,仅训练最后一个全连接层


在pytorch中,含有很多网络结构的预处理模型,这些就是迁移学习的基础。

image.png

对于【14】自定义宝可梦数据集节中,实现的自定义数据集,如果我们选择自己写的ResNet18/50网络结构去训练(详情见【15】ResNet结构的pytorch实现),以10个epoch为例,最高的准确度acc只有80%左右。但是,如果使用迁移学习的方法,以30个epoch为例,最高的准确度可以达到0.974,测试集准确也有0.94。


2.迁移学习的实现


1)自定义模型结构参考代码

train.py


import torch
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoader
from Pokemon import Pokemon
from model import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
epoch_size = 5
learning_rate = 1e-3
batch_size = 32
resize = 224
root = 'E:\学习\机器学习\数据集\pokemon'
train_data = Pokemon(root=root, resize=resize, mode='train')
val_data = Pokemon(root=root, resize=resize, mode='val')
test_data = Pokemon(root=root, resize=resize, mode='test')
train_loader = DataLoader(train_data, batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size, shuffle=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = ResNet50().to(device)
model = ResNet18()
print(model)
crition = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
best_acc = 0
best_epoch = 0
for epoch in range(epoch_size):
    # 训练集训练
    model.train()
    for batchidx, (image, label) in enumerate(test_loader):
        # image = image.to(device)
        # label = label.to(device)
        logits = model(image)
        loss = crition(logits, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batchidx%2 == 0:
            print("epoch:{}/{}, batch:{}/{}, loss:{}"
                  .format(epoch+1, epoch_size, batchidx, len(test_loader), loss))
    # 测试集挑选
    model.eval()
    correct = 0
    for image, label in val_loader:
        # image = image.to(device)
        # label = label.to(device)
        with torch.no_grad():
            logits = model(image)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, label).sum().float().item()
    acc = correct/len(val_data)
    print("epoch:{}, acc:{}".format(epoch+1, acc))
    if acc > best_acc:
        best_acc = acc
        best_epoch = epoch
        torch.save(model.state_dict(), 'best.mdl')
        print("[get best epoch]- best_acc:{}, best_epoch:{}".format(best_acc, best_epoch))


test.py


import torch
from torch.utils.data import DataLoader
from Pokemon import Pokemon
from model import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
epoch_size = 5
learning_rate = 1e-3
batch_size = 32
resize = 224
root = 'E:\学习\机器学习\数据集\pokemon'
test_data = Pokemon(root=root, resize=resize, mode='test')
test_loader = DataLoader(test_data, batch_size, shuffle=True)
model = ResNet18()
model.load_state_dict(torch.load('best.mdl'))
# 测试集验证
correct = 0
for image, label in test_loader:
    with torch.no_grad():
        logits = model(image)
        pred = logits.argmax(dim=1)
    correct += torch.eq(pred, label).sum().float().item()
print("len(test_loader):", len(test_data))
acc = correct/len(test_data)
print("final acc:", acc)


2)迁移学习模型结构参考代码

(大多数的代码是相同的,主要是模型定义部分的改变)


from torchvision.models import resnet18
from utils import Flatten
# 迁移学习的主要实现
# model = ResNet18()
trained_model = resnet18(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1],  # torch.Size([32, 512, 1, 1])
                      Flatten(),          # torch.Size([32, 512])
                      nn.Linear(512, 5)   # torch.Size([32, 5])
                      )
model.load_state_dict(torch.load('best.mdl'))


utils.py


from    matplotlib import pyplot as plt
import  torch
from    torch import nn
# 打平操作
class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()
    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)
# 显示图像
def plot_image(img, label, name):
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


对于之前的自定义结构,只需要稍微改变了几行代码,准确率就有了大大的提升,验证集也达到了0.94的效果。


所以,为了提高模型的准确率,可以使用一下迁移学习的方式。


目录
相关文章
|
存储 机器学习/深度学习 算法
MMDetection3d对KITT数据集的训练与评估介绍
MMDetection3d对KITT数据集的训练与评估介绍
2087 0
MMDetection3d对KITT数据集的训练与评估介绍
|
5月前
|
Python
模型训练
【8月更文挑战第20天】模型训练。
62 0
|
4月前
|
人工智能 自动驾驶 数据库
领域大模型的训练需要什么数据?
领域大模型的训练需要什么数据?
236 0
|
5月前
|
机器学习/深度学习 自然语言处理 数据可视化
训练模型
【8月更文挑战第1天】
58 2
|
XML 数据挖掘 数据格式
|
网络安全 开发工具 网络架构
YOLOV7详细解读(四)训练自己的数据集
YOLOV7详细解读(四)训练自己的数据集
812 0
|
算法 搜索推荐
每日训练(二)
每日训练(二),题目来源:力扣,PTA。
每日训练(二)
每日训练(一)
题目来源于PTA基础编程和力扣剑指offer
每日训练(一)

热门文章

最新文章