【15】宝可梦数据集基于迁移学习训练

简介: 【15】宝可梦数据集基于迁移学习训练

1.迁移学习的概念


迁移学习的概念就是其实我们不必去重新的训练一个网络,而是我们可以基于其他的网络,借用这个网络的权重,然后稍微的去修改少层数的权重,从而达到一个比较好的效果。

image.png


常见的迁移学习方式:

  1. 载入权重后训练所有参数
  2. 载入权重后只训练最后几层参数
  3. 载入权重后在原网络基础上再添加一层全连接层,仅训练最后一个全连接层


在pytorch中,含有很多网络结构的预处理模型,这些就是迁移学习的基础。

image.png

对于【14】自定义宝可梦数据集节中,实现的自定义数据集,如果我们选择自己写的ResNet18/50网络结构去训练(详情见【15】ResNet结构的pytorch实现),以10个epoch为例,最高的准确度acc只有80%左右。但是,如果使用迁移学习的方法,以30个epoch为例,最高的准确度可以达到0.974,测试集准确也有0.94。


2.迁移学习的实现


1)自定义模型结构参考代码

train.py


import torch
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoader
from Pokemon import Pokemon
from model import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
epoch_size = 5
learning_rate = 1e-3
batch_size = 32
resize = 224
root = 'E:\学习\机器学习\数据集\pokemon'
train_data = Pokemon(root=root, resize=resize, mode='train')
val_data = Pokemon(root=root, resize=resize, mode='val')
test_data = Pokemon(root=root, resize=resize, mode='test')
train_loader = DataLoader(train_data, batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size, shuffle=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# model = ResNet50().to(device)
model = ResNet18()
print(model)
crition = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
best_acc = 0
best_epoch = 0
for epoch in range(epoch_size):
    # 训练集训练
    model.train()
    for batchidx, (image, label) in enumerate(test_loader):
        # image = image.to(device)
        # label = label.to(device)
        logits = model(image)
        loss = crition(logits, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batchidx%2 == 0:
            print("epoch:{}/{}, batch:{}/{}, loss:{}"
                  .format(epoch+1, epoch_size, batchidx, len(test_loader), loss))
    # 测试集挑选
    model.eval()
    correct = 0
    for image, label in val_loader:
        # image = image.to(device)
        # label = label.to(device)
        with torch.no_grad():
            logits = model(image)
            pred = logits.argmax(dim=1)
        correct += torch.eq(pred, label).sum().float().item()
    acc = correct/len(val_data)
    print("epoch:{}, acc:{}".format(epoch+1, acc))
    if acc > best_acc:
        best_acc = acc
        best_epoch = epoch
        torch.save(model.state_dict(), 'best.mdl')
        print("[get best epoch]- best_acc:{}, best_epoch:{}".format(best_acc, best_epoch))


test.py


import torch
from torch.utils.data import DataLoader
from Pokemon import Pokemon
from model import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
epoch_size = 5
learning_rate = 1e-3
batch_size = 32
resize = 224
root = 'E:\学习\机器学习\数据集\pokemon'
test_data = Pokemon(root=root, resize=resize, mode='test')
test_loader = DataLoader(test_data, batch_size, shuffle=True)
model = ResNet18()
model.load_state_dict(torch.load('best.mdl'))
# 测试集验证
correct = 0
for image, label in test_loader:
    with torch.no_grad():
        logits = model(image)
        pred = logits.argmax(dim=1)
    correct += torch.eq(pred, label).sum().float().item()
print("len(test_loader):", len(test_data))
acc = correct/len(test_data)
print("final acc:", acc)


2)迁移学习模型结构参考代码

(大多数的代码是相同的,主要是模型定义部分的改变)


from torchvision.models import resnet18
from utils import Flatten
# 迁移学习的主要实现
# model = ResNet18()
trained_model = resnet18(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1],  # torch.Size([32, 512, 1, 1])
                      Flatten(),          # torch.Size([32, 512])
                      nn.Linear(512, 5)   # torch.Size([32, 5])
                      )
model.load_state_dict(torch.load('best.mdl'))


utils.py


from    matplotlib import pyplot as plt
import  torch
from    torch import nn
# 打平操作
class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()
    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)
# 显示图像
def plot_image(img, label, name):
    fig = plt.figure()
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()


对于之前的自定义结构,只需要稍微改变了几行代码,准确率就有了大大的提升,验证集也达到了0.94的效果。


所以,为了提高模型的准确率,可以使用一下迁移学习的方式。


目录
相关文章
|
4月前
|
机器学习/深度学习 弹性计算 TensorFlow
在阿里云上打造强大的模型训练服务
随着人工智能技术的迅猛发展,模型训练服务变得愈发关键。阿里云提供了一系列强大的产品,使得在云端轻松搭建、优化和管理模型训练变得更加便捷。本文将详细介绍如何使用阿里云的相关产品构建高效的模型训练服务。
194 0
|
9天前
|
机器学习/深度学习 编解码 算法
目标检测舰船数据集整合
目标检测舰船数据集整合
|
8月前
|
XML 数据挖掘 数据格式
|
5月前
|
传感器 数据采集 编解码
3D目标检测数据集 DAIR-V2X-V
本文分享国内场景3D目标检测,公开数据集 DAIR-V2X-V(也称为DAIR-V2X车端)。DAIR-V2X车端3D检测数据集是一个大规模车端多模态数据集,包括: 22325帧 图像数据 22325帧 点云数据 2D&3D标注 基于该数据集,可以进行车端3D目标检测任务研究,例如单目3D检测、点云3D检测和多模态3D检测。
102 0
|
8月前
|
机器学习/深度学习
使用卷积神经网络CNN训练minist数据集(二)
使用卷积神经网络CNN训练minist数据集(二)
|
机器学习/深度学习 编解码 算法
卷积神经网络分类算法的模型训练
卷积神经网络分类算法的模型训练
127 0
每日训练(一)
题目来源于PTA基础编程和力扣剑指offer
每日训练(一)
每日训练(五)
每日训练五,题目来源:牛客、力扣
每日训练(五)