基于Pytorch的深度学习模型保存和加载方式

简介: 基于Pytorch的深度学习模型保存和加载方式

我们在训练深度学习模型的过程中,最好对已经训练好的深度学习模型进行保存,或者方便的加载别人训练好的模型微调节省训练时间,实现高效率解决问题。

一、必要性

  • 深度学习的模型参数超级多比如:Transformer模型、Bert模型等
  • 训练的数据集一般很大,比如:1000G以上等
  • 若本地电脑的算力或者实验室的服务器算力基本不够,训练模型花费时间多,一个模型短则训练几天不能停,甚至几个月,若这时又发生内存不够等,那简直是晴天霹雳。
  • 总而言之,这时若有类似的训练好的模型可以直接拿来用然后微调是非常nice的,因此模型的保存是利己利人,有助于共建和谐社会。

二、保存模型的三种文件格式(任选一,作用上基本无区别)

  • .pt :这个后缀在官方文档里使用较多。
  • .pth :这个后缀一般大家觉得惯例使用这个。
  • .pkl:这个后缀是因为 Python 有一个序列化模块 pickle ,然后使用它保存模型时,通常会起一个以 .pkl为后缀名的文件。

三、保存模型的方法(注意:保存整个模型,而非仅仅保存模型的参数,包括模型结构)

import torch
torch.save(model, "文件绝对路径/模型文件名.pt") # 保存模型,model是深度学习模型,文件绝对路径/模型文件名.pt是保存训练好的模型的绝对路径和模型名称为模型文件名.pt
# 举例保存模型说明
from torch_geometric.data import Data
import mat4py
import scipy.sparse as sp
import random
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import DataLoader
import torch_geometric.nn as pyg_nn
import numpy as np
import warnings
from sklearn.metrics import accuracy_score
warnings.filterwarnings("ignore", category=Warning)
datas = mat4py.loadmat('J:/aidbBag.mat')  # 1600个图包(各图包含若干张图,每张图有20个节点及其属性,邻接矩阵和图编号)及其包标签
datas = datas['bags']  # 获取bags文件的内容,输出为list数据类型,行数为2,列数为1600,图包和包标签并行相邻
# 将a,b两个矩阵沿对角线方向斜着合并,空余处补0
def adjConcat(a, b):
    lena = len(a)
    lenb = len(b)
    left = np.row_stack((a, np.zeros((lenb, lena))))  # 先将a和一个len(b)*len(a)的零矩阵垂直拼接,得到左半边
    right = np.row_stack((np.zeros((lena, lenb)), b))  # 再将一个len(a)*len(b)的零矩阵和b垂直拼接,得到右半边
    result = np.hstack((left, right))  # 将左右矩阵水平拼接
    return result
# 对每个图包的数据进行预处理
dataset = []
for i in range(2):  # 行数
    for j in range(0, len(datas[i]), 2):  # 列数
        # 邻接矩阵数据预处理
        am = datas[i][j]['am']  # 图包中所有图的邻接矩阵
        # 将图包中所有图沿边角线连接拼接成一张超图matrix
        matrix = am[0]
        for w in range(len(am) - 1):
            matrix = adjConcat(matrix, am[w + 1])
            w += 1
        # 将邻接矩阵的超图转换为稀疏矩阵
        edge_index_temp = sp.coo_matrix(matrix)
        indices = np.vstack((edge_index_temp.row, edge_index_temp.col))
        edge_index = torch.LongTensor(indices)
        # 节点数据预处理
        nl = datas[i][j]['nl']  # 图包中所有图的各图的节点及其属性值,维度是[20,1]
        # 将图包中所有图的节点进行拼接
        for k in range(len(nl)):
            x = np.array(list(nl[k].values()))
            x = x.squeeze(0)
            node = torch.FloatTensor(x)
            if k > 0:
                nodes = torch.cat([nodes, node])
            else:
                nodes = node
        # 拼接成维度为[每张图片节点数20*图包中图片的数目,1]
        x = nodes
        # 图包标签预处理
        # 注意:图包标签和图包数据并行(hang)相邻
        j += 1
        if datas[i][j] == -1:
            data = Data(x=x, edge_index=edge_index, y=0)  # 构建新型data数据对象
        else:
            data = Data(x=x, edge_index=edge_index, y=1)  # 构建新型data数据对象
        # 图包标签整型数据转张量tensor,方便后面正确率结果对比
        data.y = np.array(data.y, dtype=np.float32)
        data.y = torch.LongTensor(data.y)
        # 构建数据集:为一张超图(图包中的图拼接成),图包中所有图片数目*20个节点,每个节点一个特征,Coo稀疏矩阵的边,一张超图一个超图(图包)标签
        dataset.append(data)  #将每个data数据对象加入列表
# 打乱数据集的数据
random.shuffle(dataset)
# 切分数据集,分成训练和测试两部分
train_dataset = dataset[:1600]
test_dataset = dataset[1400:1600]
# 构造模型类
class Net(torch.nn.Module):
    """构造GCN模型网络"""
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(1, 16) # 构造第一层,输入和输出通道,输入通道的大小和节点的特征维度一致
        self.conv2 = GCNConv(16, 2) # 构造第二层,输入和输出通道,输出通道的大小和图或者节点的分类数量一致,比如此程序中图标记就是二分类0和1,所以等于2
    def forward(self, data): # 前向传播
        x, edge_index, batch = data.x, data.edge_index, data.batch # 赋值
        # print(batch)
        # print(x)
        x = self.conv1(x, edge_index) # 第一层启动运算,输入为节点及特征和边的稀疏矩阵,输出结果是二维度[20张超图的所有节点数,16]
        # print(x.shape)
        x = F.relu(x) # 激活函数
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index) # 第二层启动运算,输入为节点及特征和边的稀疏矩阵,输出结果是二维度[20张超图的所有节点数,2]
        x = pyg_nn.global_max_pool(x, batch) # 池化降维,根据batch的值知道有多少张超图(每个超图的节点的分类值不同0-19),再将每张超图的节点取一个全局最大的节点作为该张超图的一个输出值
        # print(x.shape) # 输出维度变成[20,2]
        x = torch.FloatTensor(x)
        return x
# 使用GPU
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 构建模型实例
model = Net() # 构建模型实例
optimizer = torch.optim.Adam(model.parameters(), lr=0.005) # 优化器,模型参数优化计算
train_loader = DataLoader(train_dataset, batch_size=20, shuffle=False) # 加载训练数据集,训练数据中分成每批次20个超图data数据
loss_model = torch.nn.CrossEntropyLoss()
# print(len(train_dataset))
# 训练模型
model.train() # 表示模型开始训练,在使用pytorch构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是启用batch normalization和drop out。
for epoch in range(20): # 训练所有训练数据集100次
    loss_all = 0
    # 一轮epoch优化的内容
    for data in train_loader: # 每次提取训练数据集一批20张超图数据赋值给data
        # print(data)
        # data是batch_size图片的大小
        # print(data.edge_index)
        # print(data.batch.shape)
        # print(data.x.shape)
        optimizer.zero_grad() # 梯度清零
        output = model(data) # 前向传播,把一批训练数据集导入模型并返回输出结果
        label = data.y # 20张超图数据的标签集合
        # print(data.y)
        loss = loss_model(output,label) # 损失函数计算
        loss.backward() #反向传播
        loss_all += loss.item() # 将最后的损失值汇总
        optimizer.step() # 更新模型参数
    tmp = (loss_all / len(train_dataset)) # 算出损失值或者错误率
    if epoch % 20 == 0:
        print(tmp) # 每二十次训练完整个训练数据集,输出其错误率
# 保存整个model的状态,也就是model的预训练模型
torch.save(model, "E:\GCNmodel\model\MyGCNmodel.pt") # 没有定义绝对路径情况下和此文件同文件夹

四、加载模型的方法(注意:文件绝对路径/模型文件名.pt和保存模型的要完全对应否则会报错)

import torch
model=torch.load("文件绝对路径/模型文件名.pt") # 加载模型,文件绝对路径/模型文件名.pt是保存训练好的模型的绝对路径和模型名称为模型文件名.pt
# 举例加载模型说明
import torch
from torch_geometric.data import DataLoader
import numpy as np
import warnings
from sklearn.metrics import accuracy_score
warnings.filterwarnings("ignore", category=Warning)
from model import test_dataset
# 导入已训练好的GCNmodel预训练模型
model=torch.load("E:\GCNmodel\model\MyGCNmodel.pt")
# 测试
preds = [] # 预测标签列表
label = [] # 真实标签列表
loaders = DataLoader(test_dataset, batch_size=20, shuffle=False) # 读取测试数据集数据
with torch.no_grad():
    for predata in loaders:
        pred = model(predata).numpy()
        label.append(predata.y.tolist())
        for i in range(pred.shape[0]):
            tmp = pred[i].tolist()  # tensor转成列表,pred[i]表示第i张超图
            # print(tmp.index(max(tmp)))
            preds.append(tmp.index(max(tmp)))  # 从列表的两个元素选出最大的tmp.index(x)返回寻找元素x的下标,此时只有两个元素那么下标就是0和1
        preds = np.squeeze(np.array(preds)).tolist()
    # 真实超图(图包)的标签数据集
    label = [i for item in label for i in item]
# 输出结果和统计模型预测正确率
print(preds) # 输出预测的超图(图包)标签
print(label) # 输出真实的超图(图包)标签
print(accuracy_score(label, preds))  # 求出分类准确率分数是指所有分类正确的百分比率,完全正确为1


相关文章
|
1月前
|
机器学习/深度学习 人工智能 PyTorch
PyTorch深度学习 ? 带你从入门到精通!!!
🌟 蒋星熠Jaxonic,深度学习探索者。三年深耕PyTorch,从基础到部署,分享模型构建、GPU加速、TorchScript优化及PyTorch 2.0新特性,助力AI开发者高效进阶。
PyTorch深度学习 ? 带你从入门到精通!!!
|
2月前
|
机器学习/深度学习 数据采集 人工智能
PyTorch学习实战:AI从数学基础到模型优化全流程精解
本文系统讲解人工智能、机器学习与深度学习的层级关系,涵盖PyTorch环境配置、张量操作、数据预处理、神经网络基础及模型训练全流程,结合数学原理与代码实践,深入浅出地介绍激活函数、反向传播等核心概念,助力快速入门深度学习。
175 1
|
1月前
|
边缘计算 人工智能 PyTorch
130_知识蒸馏技术:温度参数与损失函数设计 - 教师-学生模型的优化策略与PyTorch实现
随着大型语言模型(LLM)的规模不断增长,部署这些模型面临着巨大的计算和资源挑战。以DeepSeek-R1为例,其671B参数的规模即使经过INT4量化后,仍需要至少6张高端GPU才能运行,这对于大多数中小型企业和研究机构来说成本过高。知识蒸馏作为一种有效的模型压缩技术,通过将大型教师模型的知识迁移到小型学生模型中,在显著降低模型复杂度的同时保留核心性能,成为解决这一问题的关键技术之一。
|
2月前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
132 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
1月前
|
机器学习/深度学习 数据采集 人工智能
深度学习实战指南:从神经网络基础到模型优化的完整攻略
🌟 蒋星熠Jaxonic,AI探索者。深耕深度学习,从神经网络到Transformer,用代码践行智能革命。分享实战经验,助你构建CV、NLP模型,共赴二进制星辰大海。
|
3月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.0性能优化实战:4种常见代码错误严重拖慢模型
我们将深入探讨图中断(graph breaks)和多图问题对性能的负面影响,并分析PyTorch模型开发中应当避免的常见错误模式。
239 9
|
2月前
|
机器学习/深度学习 数据采集 传感器
【WOA-CNN-LSTM】基于鲸鱼算法优化深度学习预测模型的超参数研究(Matlab代码实现)
【WOA-CNN-LSTM】基于鲸鱼算法优化深度学习预测模型的超参数研究(Matlab代码实现)
184 0
|
10月前
|
机器学习/深度学习 运维 安全
深度学习在安全事件检测中的应用:守护数字世界的利器
深度学习在安全事件检测中的应用:守护数字世界的利器
398 22
|
7月前
|
机器学习/深度学习 编解码 人工智能
计算机视觉五大技术——深度学习在图像处理中的应用
深度学习利用多层神经网络实现人工智能,计算机视觉是其重要应用之一。图像分类通过卷积神经网络(CNN)判断图片类别,如“猫”或“狗”。目标检测不仅识别物体,还确定其位置,R-CNN系列模型逐步优化检测速度与精度。语义分割对图像每个像素分类,FCN开创像素级分类范式,DeepLab等进一步提升细节表现。实例分割结合目标检测与语义分割,Mask R-CNN实现精准实例区分。关键点检测用于人体姿态估计、人脸特征识别等,OpenPose和HRNet等技术推动该领域发展。这些方法在效率与准确性上不断进步,广泛应用于实际场景。
971 64
计算机视觉五大技术——深度学习在图像处理中的应用
|
11月前
|
机器学习/深度学习 传感器 数据采集
深度学习在故障检测中的应用:从理论到实践
深度学习在故障检测中的应用:从理论到实践
911 6

热门文章

最新文章

推荐镜像

更多