深度学习实践篇 第五章:模型保存与加载

简介: 简要介绍pytorch中模型的保存与加载。

参考教程
https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html

训练好的模型,可以保存下来,用于后续的预测或者训练过程的重启。
为了便于理解模型保存和加载的过程,我们定义一个简单的小模型作为例子,进行后续的讲解。

这个模型里面包含一个名为self.p1的Parameter和一个名为conv1的卷积层。我们没有给模型定义forward()函数,是因为暂时不需要用到该方法。假如你想使用这个模型对数据进行前向传播,会返回 “NotImplementedError: Module [Model] is missing the required "forward" function”

import torch
import torch.nn as nn
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.t1 = torch.randn((3,2))
        self.p1 = nn.Parameter(self.t1)
        self.conv1 = nn.Conv2d(1, 1, 5)
net = Model()

pytorch中的保存与加载

首先我们来看一下pytorch中的保存和加载的方法是怎么实现的。

torch.save()

参考文档:https://pytorch.org/docs/stable/generated/torch.save.html
首先来看一下torch.save()函数。

torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

torch.save()函数传入的第一个参数,就是我们要保存的对象,它的类别要求是object,而没有限定在nn.Module()或者nn.Parameters()等等之间。说明它可以保存的类型是多种多样的,很灵活。
传入的第二个参数是f,f是一个file-like object或者文件路径,也就是我们想要保存的位置。
后面的几个参数可以不用管它,一般也不会用到。从参数名称可以看到,我们想要保存的object是以pickle的形式保存的。因为pickle支持多种数据类型。
在源码中给了两个使用torch.save的例子。

  >>> # xdoctest: +SKIP("makes cwd dirty")
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)

第一个例子把一个tensor保存在了‘tensor.pt'中,第二个则是将tensor保存在一个buffer中。这都是允许的。

torch.load()

参考文档:https://pytorch.org/docs/stable/generated/torch.load.html#torch.load
再来看一下torch.load()函数。

torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=False, **pickle_load_args)

torch.load()传入的第一个参数f对应着torch.save()中的f,它可以是一个路径,也可以是一个file-like object。
因为我们的模型训练支持cpu也支持gpu等设备,所以我们保存的object也可能处于多种设备环境中,在torch.load()时,这个object会现在CPU上进行反序列化,然后移动到其保存时所处的设备上。假如当前的系统不支持这个设备,就会出现问题,这个时候就需要使用map_location参数,这个参数可以指定你想要放置object的设备,假如没有特别指定,在设备不能实现时就会报错。
weights_only参数可以限定你先要unpickle的object的种类,在使用weights_only参数的同时,你必须明确定义pickle_moduel这个参数(默认为pickle,这也是对的),否则就会报错RuntimeError("Can not safely load weights when explicit pickle_module is specified"。一般情况下我们也不需要管这个参数。

代码示例

给出一个简单的例子,我们将一个tensor保存在’tensor.pt'中,又使用torch.load()加载进来。
image.png

因为保存支持的输入是object,所以我们即使只保存一个字符串也是可以的。(可以,但没必要)
image.png

模型的保存与加载

保存 state_dict()

在之前的章节中有说过,调用model.state_dict()方法时,得到的返回结果是一个orderdict,这个字典的key是模型中参数的名字,value是模型的参数值。
我们通常说的保存模型,保存的就是模型的state_dict(),也就是只保存了模型的参数名和参数值,因此我们是不知道模型的正确结构和forward()中的运算顺序的,你也没有办法直接使用这个state_dict()进行预测。
现在我们保存最开始定义的笨蛋小模型的state_dict()
image.png

我们只保存了模型的参数名和参数值,这个'test.pth'的大小只有1.39 KB (1,428 字节)。

nn.Module().load_state_dict()

def load_state_dict(self, state_dict: Mapping[str, Any],
                        strict: bool = True):

load_state_dict()传入的参数是一个key和value的mapping。这里的keys对应的当前模型自己的state_dict的key,或者说参数名。
在使用load_state_dict()时,该方法会对传入的mapping中的key和模型本身的key进行对比。如果key可以匹配上,就会进行一些操作后,更改模型的key对应的参数值。假如没有匹配上,这个key就会被放进missing_keys或者unexpected_keys中去。
strict这个参数默认是True,所以当有不匹配的key时,就会返回报错。

加载模型参数

我们只保存的模型的参数,所以想要使用这个参数,就需要把它放置在一个现有的模型中去。比如说我们现在有一个新模型model2,它和model1有着一样的结构,但是因为初始化的随机性,它们的参数值可能是不一样的。
image.png

可以看到我们的model2中的参数名和model1一样,但是对应的值不一样。
我们可以使用load_state_dict()方法将model1的参数值根据参数名放到model2中去。
image.png

现在model1和model2中的参数值也都变得一样了。
假如我们手动修改一下我们使用torch.load()加载的state_dict,给它增加一个新的值。加载时就会报错,出现了unexpected_keys。相应地,假如给它删除一个值,就会出现Missing key(s) 的错误,在这里不举例子。
image.png

保存模型本身

torch.save()支持保存的对象是object,而我们的模型本身,作为nn.Module(),自然也是符合object的要求的。因此你也可以直接保存整个模型。
image.png

我们保存的是整个模型,包括了模型的结构和模型的参数名+参数值。这个'test2.pth'的大小是2.39 KB (2,457 字节)。

加载模型本身

我们在上面将整个模型都保存在了'test2.pth'中,因此我们使用torch.load('test2.pth)时,获得的结果就是模型本身,它的类型是nn.Module()。
image.png

checkpoint

保存与读取

假如我们现在有一个保存好的模型'model.pth',我们想要继续当前模型的状态继续训练。这个时候我们就会发现,'model.pth'中拥有我们模型的参数名和参数值,但是随着我们之前的训练的进行,我们使用的optimizer或者lr_scheluder的状态我们是无法获取的,它们中也有一些参数可能在训练时发生了变化。
因此为了帮助我们重启训练状态,我们需要保存更多的信息,而不是只保存一个模型的state_dict。这些被保存的信息,统称为checkpoint。
在保存checkpoint时,我们同样使用torch.save()方法,在加载时,也是用torch.load()方法。因为torch.save支持保存各种格式,我们可以将想要保存的信息按照key和value组成一个dict,并将这个dict保存下来。
在下面这个例子中,被保存下来的信息包括当前的epoch数,模型的state_dict, 优化器的state_dict还有louss。

# Additional information
torch.save({
   
   
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)

在加载时,我们只要按照key取其中的value就可以。

# Additional information
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

多个模型的保存与读取

我们已经知道可以将key和value对应的dict保存成checkpoint的形式,帮助我们重启训练状态。当我们有多个模型时,只不过是增加了要保存到信息而已,方法是一样的。

# Specify a path to save to
PATH = "model.pt"

torch.save({
   
   
            'modelA_state_dict': netA.state_dict(),
            'modelB_state_dict': netB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            }, PATH)

在这个checkpoint中,我们分别保存了modelA和modelB的state_dict,和它们对应的优化器optimizerA和optimizerB的state_dict。
因此在使用时,只要分别放置到对应的object中就可以。

modelA = Net()
modelB = Net()
optimModelA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimModelB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
相关文章
|
4月前
|
机器学习/深度学习 算法 定位技术
Baumer工业相机堡盟工业相机如何通过YoloV8深度学习模型实现裂缝的检测识别(C#代码UI界面版)
本项目基于YOLOv8模型与C#界面,结合Baumer工业相机,实现裂缝的高效检测识别。支持图像、视频及摄像头输入,具备高精度与实时性,适用于桥梁、路面、隧道等多种工业场景。
525 27
|
3月前
|
机器学习/深度学习 数据可视化 算法
深度学习模型结构复杂、参数众多,如何更直观地深入理解你的模型?
深度学习模型虽应用广泛,但其“黑箱”特性导致可解释性不足,尤其在金融、医疗等敏感领域,模型决策逻辑的透明性至关重要。本文聚焦深度学习可解释性中的可视化分析,介绍模型结构、特征、参数及输入激活的可视化方法,帮助理解模型行为、提升透明度,并推动其在关键领域的安全应用。
364 0
|
2月前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
151 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
1月前
|
机器学习/深度学习 数据采集 人工智能
深度学习实战指南:从神经网络基础到模型优化的完整攻略
🌟 蒋星熠Jaxonic,AI探索者。深耕深度学习,从神经网络到Transformer,用代码践行智能革命。分享实战经验,助你构建CV、NLP模型,共赴二进制星辰大海。
|
4月前
|
机器学习/深度学习 人工智能 PyTorch
AI 基础知识从 0.2 到 0.3——构建你的第一个深度学习模型
本文以 MNIST 手写数字识别为切入点,介绍了深度学习的基本原理与实现流程,帮助读者建立起对神经网络建模过程的系统性理解。
597 15
AI 基础知识从 0.2 到 0.3——构建你的第一个深度学习模型
|
2月前
|
机器学习/深度学习 数据采集 传感器
【WOA-CNN-LSTM】基于鲸鱼算法优化深度学习预测模型的超参数研究(Matlab代码实现)
【WOA-CNN-LSTM】基于鲸鱼算法优化深度学习预测模型的超参数研究(Matlab代码实现)
206 0
|
4月前
|
机器学习/深度学习 人工智能 自然语言处理
AI 基础知识从 0.3 到 0.4——如何选对深度学习模型?
本系列文章从机器学习基础出发,逐步深入至深度学习与Transformer模型,探讨AI关键技术原理及应用。内容涵盖模型架构解析、典型模型对比、预训练与微调策略,并结合Hugging Face平台进行实战演示,适合初学者与开发者系统学习AI核心知识。
440 15
|
4月前
|
机器学习/深度学习 人工智能 自然语言处理
深度学习模型、算法与应用的全方位解析
深度学习,作为人工智能(AI)的一个重要分支,已经在多个领域产生了革命性的影响。从图像识别到自然语言处理,从语音识别到自动驾驶,深度学习无处不在。本篇博客将深入探讨深度学习的模型、算法及其在各个领域的应用。
880 3
|
5月前
|
机器学习/深度学习 存储 PyTorch
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
本文通过使用 Kaggle 数据集训练情感分析模型的实例,详细演示了如何将 PyTorch 与 MLFlow 进行深度集成,实现完整的实验跟踪、模型记录和结果可复现性管理。文章将系统性地介绍训练代码的核心组件,展示指标和工件的记录方法,并提供 MLFlow UI 的详细界面截图。
255 2
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
|
9月前
|
机器学习/深度学习 数据采集 自然语言处理
深度学习实践技巧:提升模型性能的详尽指南
深度学习模型在图像分类、自然语言处理、时间序列分析等多个领域都表现出了卓越的性能,但在实际应用中,为了使模型达到最佳效果,常规的标准流程往往不足。本文提供了多种深度学习实践技巧,包括数据预处理、模型设计优化、训练策略和评价与调参等方面的详细操作和代码示例,希望能够为应用实战提供有效的指导和支持。
下一篇
oss云网关配置