Pytorch基于迁移学习的VGG卷积神经网络-手撕(可直接运行)-部分地方不懂的可以参考我上一篇手撕VGG神经网络的注释 两个基本一样 只是这个网络是迁移过来的

简介: Pytorch基于迁移学习的VGG卷积神经网络-手撕(可直接运行)-部分地方不懂的可以参考我上一篇手撕VGG神经网络的注释 两个基本一样 只是这个网络是迁移过来的
import torch
import torchvision
import torchvision.models
from PIL import Image
from matplotlib import pyplot as plt
from tqdm import tqdm
from torch import nn
from torch.utils.data import DataLoader
from torchvision.transforms import transforms
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
toPIL = transforms.ToPILImage()  # 将图像数据转换为PIL格式
trans = transforms.Compose([transforms.Resize((120, 120)),  # 将图像统一调整为120*120大小
                            transforms.ToTensor()])  # 将图像数据转换为张量
train_data = torchvision.datasets.CIFAR10(root="./data", train=True, download=True,  # 导入CIFAR10数据集的训练集
                                          transform=trans)
traindata = DataLoader(dataset=train_data, batch_size=32, shuffle=True, num_workers=0)  # 将训练数据以每次32张图片的形式抽出进行训练
test_data = torchvision.datasets.CIFAR10(root="./data", train=False, download=False,  # 导入CIFAR10数据集的测试集
                                         transform=trans)
train_size = len(train_data)  # 训练集的长度
test_size = len(test_data)  # 测试集的长度
print(train_size)
print(test_size)
testdata = DataLoader(dataset=test_data, batch_size=32, shuffle=True, num_workers=0)  # 将训练数据以每次32张图片的形式抽出进行测试
alexnet1 = torchvision.models.vgg16(pretrained = True)   #下载预训练模型
alexnet1.add_module("linear",nn.Linear(1000 , 10))  #在预训练模型的最后一层再加上一层全连接层进行训练微调,因为本数据集是10种 而且与训练模型都是在imagenet数据集上训练的 是1000种的输出
test1 = torch.ones(64, 3, 120, 120)  # 测试一下输出的形状大小
#其他地方跟alexnet的代码一样
test1 = alexnet1(test1)
print(test1.shape)
epoch = 2  # 迭代次数
learning = 0.0001  # 学习率
optimizer = torch.optim.Adam(alexnet1.parameters(), lr=learning)  # 使用Adam优化器
loss = nn.CrossEntropyLoss()  # 损失计算方式,交叉熵
train_loss_all = []  # 存放训练集损失的数组
train_accur_all = []  # 存放训练集准确率的数组
test_loss_all = []  # 存放测试集损失的数组
test_accur_all = []  # 存放测试集准确率的数组
for i in range(epoch):
    train_loss = 0
    train_num = 0.0
    train_accuracy = 0.0
    alexnet1.train()
    train_bar = tqdm(traindata)
    for step, data in enumerate(train_bar):
        img, target = data
        optimizer.zero_grad()  # 清空历史梯度
        outputs = alexnet1(img)  # 将图片打入网络进行训练
        loss1 = loss(outputs, target)
        outputs = torch.argmax(outputs, 1)
        loss1.backward()
        optimizer.step()
        train_loss += abs(loss1.item()) * img.size(0)
        accuracy = torch.sum(outputs == target)
        train_accuracy = train_accuracy + accuracy
        train_num += img.size(0)
    print("epoch:{} , train-Loss:{} , train-accuracy:{}".format(i + 1, train_loss / train_num,
                                                                train_accuracy / train_num))
    train_loss_all.append(train_loss / train_num)
    train_accur_all.append(train_accuracy.double().item() / train_num)
    test_loss = 0
    test_accuracy = 0.0
    test_num = 0
    alexnet1.eval()
    with torch.no_grad():
        test_bar = tqdm(testdata)
        for data in test_bar:
            img, target = data
            outputs = alexnet1(img)
            loss2 = loss(outputs, target)
            outputs = torch.argmax(outputs, 1)
            test_loss = test_loss + abs(loss2.item()) * img.size(0)
            accuracy = torch.sum(outputs == target)
            test_accuracy = test_accuracy + accuracy
            test_num += img.size(0)
    print("test-Loss:{} , test-accuracy:{}".format(test_loss / test_num, test_accuracy / test_num))
    test_loss_all.append(test_loss / test_num)
    test_accur_all.append(test_accuracy.double().item() / test_num)
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(range(epoch), train_loss_all,
         "ro-", label="Train loss")
plt.plot(range(epoch), test_loss_all,
         "bs-", label="test loss")
plt.legend()
plt.xlabel("epoch")
plt.ylabel("Loss")
plt.subplot(1, 2, 2)
plt.plot(range(epoch), train_accur_all,
         "ro-", label="Train accur")
plt.plot(range(epoch), test_accur_all,
         "bs-", label="test accur")
plt.xlabel("epoch")
plt.ylabel("acc")
plt.legend()
plt.show()
torch.save(alexnet1, "xiaozhai.pth")
print("模型已保存")
相关文章
|
4天前
|
NoSQL Java Redis
Redis系列学习文章分享---第十八篇(Redis原理篇--网络模型,通讯协议,内存回收)
Redis系列学习文章分享---第十八篇(Redis原理篇--网络模型,通讯协议,内存回收)
12 0
|
4天前
|
存储 消息中间件 缓存
Redis系列学习文章分享---第十七篇(Redis原理篇--数据结构,网络模型)
Redis系列学习文章分享---第十七篇(Redis原理篇--数据结构,网络模型)
12 0
|
6天前
|
网络协议
计算机网络学习记录 运输层 Day5(2)
计算机网络学习记录 运输层 Day5(2)
10 1
|
6天前
计算机网络学习记录 应用层 Day6(2)
计算机网络学习记录 应用层 Day6(2)
9 0
|
6天前
|
网络协议
计算机网络学习记录 应用层 Day6(1)
计算机网络学习记录 应用层 Day6(1)
8 0
|
6天前
|
网络协议 算法 网络性能优化
计算机网络学习记录 运输层 Day5(1)
计算机网络学习记录 运输层 Day5(1)
8 0
|
6天前
|
网络虚拟化 网络架构
计算机网络学习记录 网络层 Day4(下)(2)
计算机网络学习记录 网络层 Day4(下)(2)
9 0
|
6天前
|
算法 网络协议 网络架构
计算机网络学习记录 网络层 Day4(下)(1)
计算机网络学习记录 网络层 Day4(下)(1)
9 0
|
6天前
计算机网络学习记录 网络层 Day4(上)(2)
计算机网络学习记录 网络层 Day4(上)(2)
6 0
|
6天前
|
网络协议 网络虚拟化 网络架构
计算机网络学习记录 数据链路层 Day3 (下)(2)
计算机网络学习记录 数据链路层 Day3 (下)(2)
7 0

热门文章

最新文章