深度学习实践篇 第五章:模型保存与加载

简介: 简要介绍pytorch中模型的保存与加载。

参考教程
https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html

训练好的模型,可以保存下来,用于后续的预测或者训练过程的重启。
为了便于理解模型保存和加载的过程,我们定义一个简单的小模型作为例子,进行后续的讲解。

这个模型里面包含一个名为self.p1的Parameter和一个名为conv1的卷积层。我们没有给模型定义forward()函数,是因为暂时不需要用到该方法。假如你想使用这个模型对数据进行前向传播,会返回 “NotImplementedError: Module [Model] is missing the required "forward" function”

import torch
import torch.nn as nn
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.t1 = torch.randn((3,2))
        self.p1 = nn.Parameter(self.t1)
        self.conv1 = nn.Conv2d(1, 1, 5)
net = Model()

pytorch中的保存与加载

首先我们来看一下pytorch中的保存和加载的方法是怎么实现的。

torch.save()

参考文档:https://pytorch.org/docs/stable/generated/torch.save.html
首先来看一下torch.save()函数。

torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)

torch.save()函数传入的第一个参数,就是我们要保存的对象,它的类别要求是object,而没有限定在nn.Module()或者nn.Parameters()等等之间。说明它可以保存的类型是多种多样的,很灵活。
传入的第二个参数是f,f是一个file-like object或者文件路径,也就是我们想要保存的位置。
后面的几个参数可以不用管它,一般也不会用到。从参数名称可以看到,我们想要保存的object是以pickle的形式保存的。因为pickle支持多种数据类型。
在源码中给了两个使用torch.save的例子。

  >>> # xdoctest: +SKIP("makes cwd dirty")
        >>> # Save to file
        >>> x = torch.tensor([0, 1, 2, 3, 4])
        >>> torch.save(x, 'tensor.pt')
        >>> # Save to io.BytesIO buffer
        >>> buffer = io.BytesIO()
        >>> torch.save(x, buffer)

第一个例子把一个tensor保存在了‘tensor.pt'中,第二个则是将tensor保存在一个buffer中。这都是允许的。

torch.load()

参考文档:https://pytorch.org/docs/stable/generated/torch.load.html#torch.load
再来看一下torch.load()函数。

torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=False, **pickle_load_args)

torch.load()传入的第一个参数f对应着torch.save()中的f,它可以是一个路径,也可以是一个file-like object。
因为我们的模型训练支持cpu也支持gpu等设备,所以我们保存的object也可能处于多种设备环境中,在torch.load()时,这个object会现在CPU上进行反序列化,然后移动到其保存时所处的设备上。假如当前的系统不支持这个设备,就会出现问题,这个时候就需要使用map_location参数,这个参数可以指定你想要放置object的设备,假如没有特别指定,在设备不能实现时就会报错。
weights_only参数可以限定你先要unpickle的object的种类,在使用weights_only参数的同时,你必须明确定义pickle_moduel这个参数(默认为pickle,这也是对的),否则就会报错RuntimeError("Can not safely load weights when explicit pickle_module is specified"。一般情况下我们也不需要管这个参数。

代码示例

给出一个简单的例子,我们将一个tensor保存在’tensor.pt'中,又使用torch.load()加载进来。
image.png

因为保存支持的输入是object,所以我们即使只保存一个字符串也是可以的。(可以,但没必要)
image.png

模型的保存与加载

保存 state_dict()

在之前的章节中有说过,调用model.state_dict()方法时,得到的返回结果是一个orderdict,这个字典的key是模型中参数的名字,value是模型的参数值。
我们通常说的保存模型,保存的就是模型的state_dict(),也就是只保存了模型的参数名和参数值,因此我们是不知道模型的正确结构和forward()中的运算顺序的,你也没有办法直接使用这个state_dict()进行预测。
现在我们保存最开始定义的笨蛋小模型的state_dict()
image.png

我们只保存了模型的参数名和参数值,这个'test.pth'的大小只有1.39 KB (1,428 字节)。

nn.Module().load_state_dict()

def load_state_dict(self, state_dict: Mapping[str, Any],
                        strict: bool = True):

load_state_dict()传入的参数是一个key和value的mapping。这里的keys对应的当前模型自己的state_dict的key,或者说参数名。
在使用load_state_dict()时,该方法会对传入的mapping中的key和模型本身的key进行对比。如果key可以匹配上,就会进行一些操作后,更改模型的key对应的参数值。假如没有匹配上,这个key就会被放进missing_keys或者unexpected_keys中去。
strict这个参数默认是True,所以当有不匹配的key时,就会返回报错。

加载模型参数

我们只保存的模型的参数,所以想要使用这个参数,就需要把它放置在一个现有的模型中去。比如说我们现在有一个新模型model2,它和model1有着一样的结构,但是因为初始化的随机性,它们的参数值可能是不一样的。
image.png

可以看到我们的model2中的参数名和model1一样,但是对应的值不一样。
我们可以使用load_state_dict()方法将model1的参数值根据参数名放到model2中去。
image.png

现在model1和model2中的参数值也都变得一样了。
假如我们手动修改一下我们使用torch.load()加载的state_dict,给它增加一个新的值。加载时就会报错,出现了unexpected_keys。相应地,假如给它删除一个值,就会出现Missing key(s) 的错误,在这里不举例子。
image.png

保存模型本身

torch.save()支持保存的对象是object,而我们的模型本身,作为nn.Module(),自然也是符合object的要求的。因此你也可以直接保存整个模型。
image.png

我们保存的是整个模型,包括了模型的结构和模型的参数名+参数值。这个'test2.pth'的大小是2.39 KB (2,457 字节)。

加载模型本身

我们在上面将整个模型都保存在了'test2.pth'中,因此我们使用torch.load('test2.pth)时,获得的结果就是模型本身,它的类型是nn.Module()。
image.png

checkpoint

保存与读取

假如我们现在有一个保存好的模型'model.pth',我们想要继续当前模型的状态继续训练。这个时候我们就会发现,'model.pth'中拥有我们模型的参数名和参数值,但是随着我们之前的训练的进行,我们使用的optimizer或者lr_scheluder的状态我们是无法获取的,它们中也有一些参数可能在训练时发生了变化。
因此为了帮助我们重启训练状态,我们需要保存更多的信息,而不是只保存一个模型的state_dict。这些被保存的信息,统称为checkpoint。
在保存checkpoint时,我们同样使用torch.save()方法,在加载时,也是用torch.load()方法。因为torch.save支持保存各种格式,我们可以将想要保存的信息按照key和value组成一个dict,并将这个dict保存下来。
在下面这个例子中,被保存下来的信息包括当前的epoch数,模型的state_dict, 优化器的state_dict还有louss。

# Additional information
torch.save({
   
   
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)

在加载时,我们只要按照key取其中的value就可以。

# Additional information
model = Net()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

多个模型的保存与读取

我们已经知道可以将key和value对应的dict保存成checkpoint的形式,帮助我们重启训练状态。当我们有多个模型时,只不过是增加了要保存到信息而已,方法是一样的。

# Specify a path to save to
PATH = "model.pt"

torch.save({
   
   
            'modelA_state_dict': netA.state_dict(),
            'modelB_state_dict': netB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            }, PATH)

在这个checkpoint中,我们分别保存了modelA和modelB的state_dict,和它们对应的优化器optimizerA和optimizerB的state_dict。
因此在使用时,只要分别放置到对应的object中就可以。

modelA = Net()
modelB = Net()
optimModelA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimModelB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()
相关文章
|
6天前
|
机器学习/深度学习 算法 测试技术
Python中实现多层感知机(MLP)的深度学习模型
Python中实现多层感知机(MLP)的深度学习模型
37 0
|
3天前
|
机器学习/深度学习 算法
揭秘深度学习中的对抗性网络:理论与实践
【5月更文挑战第18天】 在深度学习领域的众多突破中,对抗性网络(GANs)以其独特的机制和强大的生成能力受到广泛关注。不同于传统的监督学习方法,GANs通过同时训练生成器与判别器两个模型,实现了无监督学习下的高效数据生成。本文将深入探讨对抗性网络的核心原理,解析其数学模型,并通过案例分析展示GANs在图像合成、风格迁移及增强学习等领域的应用。此外,我们还将讨论当前GANs面临的挑战以及未来的发展方向,为读者提供一个全面而深入的视角以理解这一颠覆性技术。
|
4天前
|
机器学习/深度学习 人工智能 算法
【AI】从零构建深度学习框架实践
【5月更文挑战第16天】 本文介绍了从零构建一个轻量级的深度学习框架tinynn,旨在帮助读者理解深度学习的基本组件和框架设计。构建过程包括设计框架架构、实现基本功能、模型定义、反向传播算法、训练和推理过程以及性能优化。文章详细阐述了网络层、张量、损失函数、优化器等组件的抽象和实现,并给出了一个基于MNIST数据集的分类示例,与TensorFlow进行了简单对比。tinynn的源代码可在GitHub上找到,目前支持多种层、损失函数和优化器,适用于学习和实验新算法。
57 2
|
4天前
|
机器学习/深度学习 数据可视化 PyTorch
使用Python实现深度学习模型:变分自编码器(VAE)
使用Python实现深度学习模型:变分自编码器(VAE)
13 2
|
5天前
|
机器学习/深度学习 数据可视化 PyTorch
使用Python实现深度学习模型:生成对抗网络(GAN)
使用Python实现深度学习模型:生成对抗网络(GAN)
20 3
|
5天前
|
机器学习/深度学习 数据可视化 PyTorch
使用Python实现深度学习模型:自动编码器(Autoencoder)
使用Python实现深度学习模型:自动编码器(Autoencoder)
10 0
|
6天前
|
机器学习/深度学习 数据采集 人工智能
深度学习中的大模型「幻觉」问题:解析、原因及未来展望
深度学习中的大模型「幻觉」问题:解析、原因及未来展望
25 0
|
6天前
|
机器学习/深度学习 TensorFlow API
Python深度学习基于Tensorflow(3)Tensorflow 构建模型
Python深度学习基于Tensorflow(3)Tensorflow 构建模型
71 2
|
6天前
|
机器学习/深度学习 自然语言处理 算法
Python遗传算法GA对长短期记忆LSTM深度学习模型超参数调优分析司机数据|附数据代码
Python遗传算法GA对长短期记忆LSTM深度学习模型超参数调优分析司机数据|附数据代码
|
6天前
|
机器学习/深度学习 人工智能 自然语言处理
深度理解深度学习:从理论到实践的探索
【5月更文挑战第3天】 在人工智能的浪潮中,深度学习以其卓越的性能和广泛的应用成为了研究的热点。本文将深入探讨深度学习的核心理论,解析其背后的数学原理,并通过实际案例分析如何将这些理论应用于解决现实世界的问题。我们将从神经网络的基础结构出发,逐步过渡到复杂的模型架构,同时讨论优化算法和正则化技巧。通过本文,读者将对深度学习有一个全面而深刻的认识,并能够在实践中更加得心应手地应用这些技术。