通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()

简介: 通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()

0. 前言

深度学习实际应用中,往往涉及到的神经元网络模型都很大,权重参数众多,因此会导致训练epoch次数很多,训练时间长。


如果每次调整非模型相关的参数(训练数据集、优化函数类型、学习率、迭代次数)都要重新训练一次模型,这显然会浪费大量的训练时间。


而且,对于一些成熟的网络模型,已经有前人做过大量的“预训练”,这时如果能基于前人预训练的结果,训练自己的数据集,明显会事半功倍。


因此,加载与保存权重在深度学习实际使用中有很大的必要。


1. Pytorch框架加载与保存权重的方法

①加载权重的方法: .load_state_dict()方法说明:

.load_state_dict()定义:

def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',strict: bool = True):

- state_dict :即要加载的权重,通常是一个文件地址;

- strick: 可以理解为等于"True"时是“精确匹配”,要求要加载的权重与要被加载权重的模型完全匹配。

Pytorch源文件注释:

Args:

   state_dict (dict): a dict containing parameters and

       persistent buffers.

   strict (bool, optional): whether to strictly enforce that the keys

       in :attr:`state_dict` match the keys returned by this module's

       :meth:`~torch.nn.Module.state_dict` function. Default: ``True``

*小注释:meth笔误了,应该是mesh,网格

②保存权重的方法:.save()方法说明:

.save()定义:

def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]],
         pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:

- obj:要保存的权重参数;

- f:保存的文件路径

这里仅说明.save()在保存网络模型权重数据上的作用。实际上.save()还有很多应用,例如:保存整个网络,这里不再赘述。

Pytorch源文件注释:

"""Saves an object to a disk file.


See also: `saving-loading-tensors`


Args:

   obj: saved object

   f: a file-like object (has to implement write and flush) or a string or

      os.PathLike object containing a file name

   pickle_module: module used for pickling metadata and objects

   pickle_protocol: can be specified to override the default protocol

2. 实例问题说明


首先说明本次的实例问题:本次要构建的神经元网络为一个“平方网络”,即网络输出数据为输入数据的平方。


网络模型结构:


输入(1)→全连接层(1×5)→Sigmoid激活函数(5)→全连接层(5×5)→Sigmoid激活函数(5)→全连接层(5×1)→输出(1)


训练数据:


输入数据[1, 2, 3, 4, 5];

输出数据[1, 4, 9, 16, 25]


3. 加载权重数据

直接上代码


import torch
 
class LinearNet(torch.nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_size, out_features= 5, bias=True),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features= 5, out_features=5, bias=True),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=5, out_features=output_size, bias=True)
        )
 
    def forward(self,x):
        return self.net(x)
 
square_net = LinearNet(1,1)
 
square_net.load_state_dict(torch.load('weight.pth'))  #直接加载已经训练好的权重
 
if __name__ == '__main__':
    print(square_net(torch.tensor([3.16],dtype=torch.float32)))


其中weight.pth是我已经训练好的权重数据路径,这里定义好网络模型后,直接加载权重数据,不必关心这个权重是如何训练来的,更不必关系具体权重的值是多少。测试输入为3.16输出为:


tensor([9.9180], grad_fn=<AddBackward0>)


这里要注意的是:因为上面strict默认为True,即为“精确匹配”,这里新构建的网络模型结构必须和权重来源的网络模型结构相同


4. 保存权重数据

import torch
 
input = torch.tensor([[1],[2],[3],[4],[5]], dtype=torch.float32)
output = torch.tensor([[1],[4],[9],[16],[25]], dtype=torch.float32)
 
class LinearNet(torch.nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_size, out_features= 5, bias=True),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features= 5, out_features=5, bias=True),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=5, out_features=output_size, bias=True)
        )
 
    def forward(self,x):
        return self.net(x)
 
Loss = torch.nn.MSELoss()
linear_net = LinearNet(1,1)
opt = torch.optim.SGD(linear_net.parameters(), lr= 0.003)
 
for k in range(1000):
    opt.zero_grad()
    for i in range(len(input)):
        train_out = linear_net(input[i])
        loss = Loss(train_out, output[i])
 
        loss.backward()
        opt.step()
 
torch.save(linear_net.state_dict(),'weight.pth')   #保存.pth权重文件
 
for keys,values in linear_net.state_dict().items():   #查看权重名称及值
    print(keys)
    print(values)
    print('************************************************************************')
 
if __name__ == '__main__':
    print(linear_net(torch.tensor([3.16],dtype=torch.float32)))


这里可以看到具体的训练过程及相关的训练参数,权重保存在'weight.pth'文件中。


可以通过print查看具体的权重数值:


net.0.weight
tensor([[-0.8204],
        [-1.7341],
        [-0.6987],
        [ 0.9370],
        [-1.5558]])
************************************************************************
net.0.bias
tensor([ 0.9285,  2.1061,  1.0247, -2.9221,  7.1159])
************************************************************************
net.2.weight
tensor([[-1.6075, -1.3072, -1.5342,  2.4527, -3.9922],
        [-0.7101, -1.5125, -0.6791,  2.0325, -2.3406],
        [-1.1707, -1.6899, -0.9883,  2.9682, -1.5409],
        [-1.1992, -2.0559, -0.7610,  2.3890, -1.3782],
        [-1.1274, -1.7907, -1.0860,  2.3549, -3.6847]])
************************************************************************
net.2.bias
tensor([0.4826, 0.7057, 0.9702, 1.0532, 0.4214])
************************************************************************
net.4.weight
tensor([[7.3601, 4.7667, 6.2473, 5.0187, 7.2028]])
************************************************************************
net.4.bias
tensor([-0.2476])
************************************************************************


相关文章
|
1月前
|
PyTorch Linux 算法框架/工具
pytorch学习一:Anaconda下载、安装、配置环境变量。anaconda创建多版本python环境。安装 pytorch。
这篇文章是关于如何使用Anaconda进行Python环境管理,包括下载、安装、配置环境变量、创建多版本Python环境、安装PyTorch以及使用Jupyter Notebook的详细指南。
249 1
pytorch学习一:Anaconda下载、安装、配置环境变量。anaconda创建多版本python环境。安装 pytorch。
|
5月前
|
机器学习/深度学习 自然语言处理 算法
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
【从零开始学习深度学习】49.Pytorch_NLP项目实战:文本情感分类---使用循环神经网络RNN
|
1月前
|
机器学习/深度学习 缓存 PyTorch
pytorch学习一(扩展篇):miniconda下载、安装、配置环境变量。miniconda创建多版本python环境。整理常用命令(亲测ok)
这篇文章是关于如何下载、安装和配置Miniconda,以及如何使用Miniconda创建和管理Python环境的详细指南。
349 0
pytorch学习一(扩展篇):miniconda下载、安装、配置环境变量。miniconda创建多版本python环境。整理常用命令(亲测ok)
|
3月前
|
存储 PyTorch API
Pytorch入门—Tensors张量的学习
Pytorch入门—Tensors张量的学习
29 0
|
5月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】47. Pytorch图片样式迁移实战:将一张图片样式迁移至另一张图片,创作自己喜欢风格的图片【含完整源码】
【从零开始学习深度学习】47. Pytorch图片样式迁移实战:将一张图片样式迁移至另一张图片,创作自己喜欢风格的图片【含完整源码】
|
5月前
|
机器学习/深度学习 资源调度 PyTorch
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
【从零开始学习深度学习】15. Pytorch实战Kaggle比赛:房价预测案例【含数据集与源码】
|
5月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】50.Pytorch_NLP项目实战:卷积神经网络textCNN在文本情感分类的运用
【从零开始学习深度学习】50.Pytorch_NLP项目实战:卷积神经网络textCNN在文本情感分类的运用
|
5月前
|
机器学习/深度学习 自然语言处理 PyTorch
【从零开始学习深度学习】48.Pytorch_NLP实战案例:如何使用预训练的词向量模型求近义词和类比词
【从零开始学习深度学习】48.Pytorch_NLP实战案例:如何使用预训练的词向量模型求近义词和类比词
|
5月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
|
5月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】44. 图像增广的几种常用方式并使用图像增广训练模型【Pytorch】
【从零开始学习深度学习】44. 图像增广的几种常用方式并使用图像增广训练模型【Pytorch】

热门文章

最新文章