【从零开始学习深度学习】19. Pytorch中如何存储与读取模型:torch.save、torch.load与state_dict对象

简介: 【从零开始学习深度学习】19. Pytorch中如何存储与读取模型:torch.save、torch.load与state_dict对象

1. 读写Tensor

我们可以直接使用save函数和load函数分别存储和读取Tensorsave使用Python的pickle实用程序将对象进行序列化,然后将序列化的对象保存到disk,使用save可以保存各种对象,包括模型、张量和字典等。而load使用pickle unpickle工具将pickle的对象文件反序列化为内存。

下面的例子创建了Tensor变量x,并将其存在文件名同为x.pt的文件里。

import torch
from torch import nn
x = torch.ones(3)
torch.save(x, 'x.pt')

然后我们将数据从存储的文件读回内存。

x2 = torch.load('x.pt')
x2

输出:

tensor([1., 1., 1.])

我们还可以存储一个Tensor列表并读回内存。

y = torch.zeros(4)
torch.save([x, y], 'xy.pt')
xy_list = torch.load('xy.pt')
xy_list

输出:

[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]

存储并读取一个从字符串映射到Tensor的字典。

torch.save({'x': x, 'y': y}, 'xy_dict.pt')
xy = torch.load('xy_dict.pt')
xy

输出:

{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}

2. 读写模型

2.1 state_dict

在PyTorch中,Module的可学习参数(即权重和偏差),模块模型包含在参数中(通过model.parameters()访问)。state_dict是一个从参数名称映射到参数Tesnor的字典对象。

class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)
    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)
net = MLP()
net.state_dict()

输出:

OrderedDict([('hidden.weight', tensor([[ 0.2448,  0.1856, -0.5678],
                      [ 0.2030, -0.2073, -0.0104]])),
             ('hidden.bias', tensor([-0.3117, -0.4232])),
             ('output.weight', tensor([[-0.4556,  0.4084]])),
             ('output.bias', tensor([-0.3573]))])

注意,只有具有可学习参数的层(卷积层、线性层等)才有state_dict中的条目。优化器(optim)也有一个state_dict,其中包含关于优化器状态以及所使用的超参数的信息。

optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer.state_dict()

输出:

{'param_groups': [{'dampening': 0,
   'lr': 0.001,
   'momentum': 0.9,
   'nesterov': False,
   'params': [4736167728, 4736166648, 4736167368, 4736165352],
   'weight_decay': 0}],
 'state': {}}

2.2 保存和加载模型

PyTorch中保存和加载训练模型有两种常见的方法:

  1. 仅保存和加载模型参数(state_dict);
  2. 保存和加载整个模型。
1. 保存和加载state_dict(推荐方式)

保存:

torch.save(model.state_dict(), PATH) # 推荐文件后缀名是pt或pth

加载:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
2. 保存和加载整个模型

保存:

torch.save(model, PATH)

加载:

model = torch.load(PATH)

我们采用推荐的方法一来实验一下:

X = torch.randn(2, 3)
Y = net(X)
PATH = "./net.pt"
torch.save(net.state_dict(), PATH)
net2 = MLP()
net2.load_state_dict(torch.load(PATH))
Y2 = net2(X)
Y2 == Y

输出:

tensor([[1],
        [1]], dtype=torch.uint8)

因为这netnet2都有同样的模型参数,那么对同一个输入X的计算结果将会是一样的。上面的输出也验证了这一点。

此外,还有一些其他使用场景,例如GPU与CPU之间的模型保存与读取、使用多块GPU的模型的存储等等,使用的时候可以参考官方文档

总结

  • 通过save函数和load函数可以很方便地读写Tensor
  • 通过save函数和load_state_dict函数可以很方便地读写模型的参数。
相关文章
|
7小时前
|
机器学习/深度学习 数据采集 TensorFlow
深度学习与传统模型的桥梁:Sklearn与Keras的集成应用
【7月更文第24天】在机器学习领域,Scikit-learn(Sklearn)作为经典的传统机器学习库,以其丰富的预处理工具、模型选择和评估方法而闻名;而Keras作为深度学习领域的明星框架,以其简洁易用的API,支持快速构建和实验复杂的神经网络模型。将这两者结合起来,可以实现从传统机器学习到深度学习的无缝过渡,充分发挥各自的优势,打造更强大、更灵活的解决方案。本文将探讨Sklearn与Keras的集成应用,通过实例展示如何在Sklearn的生态系统中嵌入Keras模型,实现模型的训练、评估与优化。
13 0
|
1天前
|
机器学习/深度学习 搜索推荐 TensorFlow
使用Python实现深度学习模型:个性化推荐与广告优化
【7月更文挑战第22天】 使用Python实现深度学习模型:个性化推荐与广告优化
125 70
|
2天前
|
机器学习/深度学习 存储 自然语言处理
使用Python实现深度学习模型:语言翻译与多语种处理
【7月更文挑战第21天】 使用Python实现深度学习模型:语言翻译与多语种处理
13 0
|
3天前
|
机器学习/深度学习 自然语言处理 机器人
使用Python实现深度学习模型:自然语言理解与问答系统
【7月更文挑战第20天】 使用Python实现深度学习模型:自然语言理解与问答系统
12 0
使用Python实现深度学习模型:自然语言理解与问答系统
|
5天前
|
机器学习/深度学习 TensorFlow 语音技术
使用Python实现深度学习模型:语音合成与语音转换
【7月更文挑战第19天】 使用Python实现深度学习模型:语音合成与语音转换
19 1
|
1天前
|
机器学习/深度学习 计算机视觉
探索深度学习在图像识别中的应用与挑战
随着人工智能技术的飞速发展,深度学习已经成为图像识别领域的核心技术之一。本文将深入探讨深度学习模型在图像识别任务中的应用,分析其背后的技术原理,并指出当前面临的主要挑战和未来的发展方向。通过具体案例分析,本文旨在为读者提供对深度学习在图像识别中作用的全面理解,同时指出技术进步带来的潜在问题和伦理考量。
12 1
|
1天前
|
机器学习/深度学习 自然语言处理 监控
深度学习在自然语言处理中的应用与挑战
随着人工智能技术的飞速发展,深度学习已成为自然语言处理(NLP)领域的核心技术之一。本文深入探讨了深度学习在自然语言处理中的广泛应用,包括语音识别、机器翻译、文本分类等,并分析了当前面临的主要挑战,如模型的可解释性、数据偏差问题以及对抗性样本的防御。通过综合最新的研究成果和实际应用案例,本文旨在为读者提供一个关于深度学习在NLP领域应用的全面视角,并对未来发展趋势进行展望。
13 2
|
1天前
|
机器学习/深度学习 边缘计算 监控
深度学习在图像识别中的应用与挑战
随着人工智能技术的飞速发展,深度学习已成为推动计算机视觉领域进步的核心动力。本文旨在探讨深度学习在图像识别中的关键应用,分析其面临的主要技术挑战,并提出相应的解决策略。通过深入讨论卷积神经网络(CNN)的结构优化、数据增强技术、以及对抗性攻击的防御机制,本文揭示了深度学习模型在处理复杂视觉任务时的强大潜力和存在的局限性。此外,文章还关注了深度学习在实际应用中的资源消耗问题,并提出了提高效率的可能途径。
10 2
|
1天前
|
机器学习/深度学习 计算机视觉
深度学习在医疗影像分析中的应用
深度学习技术在医疗影像分析领域取得了显著进展,为医生提供了更准确、更快速的诊断工具。本文将探讨深度学习在医疗影像分析中的应用,包括图像分类、目标检测和分割等方面。通过具体的案例和数据分析,我们将展示深度学习如何提高医疗影像分析的准确性和效率,并讨论其在未来发展的潜力和挑战。
|
1天前
|
机器学习/深度学习 人工智能 监控
探索深度学习在图像识别中的应用与挑战
随着计算能力的飞速提升和大数据时代的来临,深度学习已经成为推动人工智能发展的核心动力。特别是在图像识别领域,深度学习技术通过模拟人脑处理信息的机制,已经取得了令人瞩目的成就。本文将深入探讨深度学习在图像识别中的关键技术、应用场景以及面临的主要挑战,为读者提供一篇内容丰富、数据支撑的技术分析文章。