通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()

简介: 通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()

0. 前言

深度学习实际应用中,往往涉及到的神经元网络模型都很大,权重参数众多,因此会导致训练epoch次数很多,训练时间长。


如果每次调整非模型相关的参数(训练数据集、优化函数类型、学习率、迭代次数)都要重新训练一次模型,这显然会浪费大量的训练时间。


而且,对于一些成熟的网络模型,已经有前人做过大量的“预训练”,这时如果能基于前人预训练的结果,训练自己的数据集,明显会事半功倍。


因此,加载与保存权重在深度学习实际使用中有很大的必要。


1. Pytorch框架加载与保存权重的方法

①加载权重的方法: .load_state_dict()方法说明:

.load_state_dict()定义:

def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]',strict: bool = True):

- state_dict :即要加载的权重,通常是一个文件地址;

- strick: 可以理解为等于"True"时是“精确匹配”,要求要加载的权重与要被加载权重的模型完全匹配。

Pytorch源文件注释:

Args:

   state_dict (dict): a dict containing parameters and

       persistent buffers.

   strict (bool, optional): whether to strictly enforce that the keys

       in :attr:`state_dict` match the keys returned by this module's

       :meth:`~torch.nn.Module.state_dict` function. Default: ``True``

*小注释:meth笔误了,应该是mesh,网格

②保存权重的方法:.save()方法说明:

.save()定义:

def save(obj, f: Union[str, os.PathLike, BinaryIO, IO[bytes]],
         pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True) -> None:

- obj:要保存的权重参数;

- f:保存的文件路径

这里仅说明.save()在保存网络模型权重数据上的作用。实际上.save()还有很多应用,例如:保存整个网络,这里不再赘述。

Pytorch源文件注释:

"""Saves an object to a disk file.


See also: `saving-loading-tensors`


Args:

   obj: saved object

   f: a file-like object (has to implement write and flush) or a string or

      os.PathLike object containing a file name

   pickle_module: module used for pickling metadata and objects

   pickle_protocol: can be specified to override the default protocol

2. 实例问题说明


首先说明本次的实例问题:本次要构建的神经元网络为一个“平方网络”,即网络输出数据为输入数据的平方。


网络模型结构:


输入(1)→全连接层(1×5)→Sigmoid激活函数(5)→全连接层(5×5)→Sigmoid激活函数(5)→全连接层(5×1)→输出(1)


训练数据:


输入数据[1, 2, 3, 4, 5];

输出数据[1, 4, 9, 16, 25]


3. 加载权重数据

直接上代码


import torch
 
class LinearNet(torch.nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_size, out_features= 5, bias=True),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features= 5, out_features=5, bias=True),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=5, out_features=output_size, bias=True)
        )
 
    def forward(self,x):
        return self.net(x)
 
square_net = LinearNet(1,1)
 
square_net.load_state_dict(torch.load('weight.pth'))  #直接加载已经训练好的权重
 
if __name__ == '__main__':
    print(square_net(torch.tensor([3.16],dtype=torch.float32)))


其中weight.pth是我已经训练好的权重数据路径,这里定义好网络模型后,直接加载权重数据,不必关心这个权重是如何训练来的,更不必关系具体权重的值是多少。测试输入为3.16输出为:


tensor([9.9180], grad_fn=<AddBackward0>)


这里要注意的是:因为上面strict默认为True,即为“精确匹配”,这里新构建的网络模型结构必须和权重来源的网络模型结构相同


4. 保存权重数据

import torch
 
input = torch.tensor([[1],[2],[3],[4],[5]], dtype=torch.float32)
output = torch.tensor([[1],[4],[9],[16],[25]], dtype=torch.float32)
 
class LinearNet(torch.nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_features=input_size, out_features= 5, bias=True),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features= 5, out_features=5, bias=True),
            torch.nn.Sigmoid(),
            torch.nn.Linear(in_features=5, out_features=output_size, bias=True)
        )
 
    def forward(self,x):
        return self.net(x)
 
Loss = torch.nn.MSELoss()
linear_net = LinearNet(1,1)
opt = torch.optim.SGD(linear_net.parameters(), lr= 0.003)
 
for k in range(1000):
    opt.zero_grad()
    for i in range(len(input)):
        train_out = linear_net(input[i])
        loss = Loss(train_out, output[i])
 
        loss.backward()
        opt.step()
 
torch.save(linear_net.state_dict(),'weight.pth')   #保存.pth权重文件
 
for keys,values in linear_net.state_dict().items():   #查看权重名称及值
    print(keys)
    print(values)
    print('************************************************************************')
 
if __name__ == '__main__':
    print(linear_net(torch.tensor([3.16],dtype=torch.float32)))


这里可以看到具体的训练过程及相关的训练参数,权重保存在'weight.pth'文件中。


可以通过print查看具体的权重数值:


net.0.weight
tensor([[-0.8204],
        [-1.7341],
        [-0.6987],
        [ 0.9370],
        [-1.5558]])
************************************************************************
net.0.bias
tensor([ 0.9285,  2.1061,  1.0247, -2.9221,  7.1159])
************************************************************************
net.2.weight
tensor([[-1.6075, -1.3072, -1.5342,  2.4527, -3.9922],
        [-0.7101, -1.5125, -0.6791,  2.0325, -2.3406],
        [-1.1707, -1.6899, -0.9883,  2.9682, -1.5409],
        [-1.1992, -2.0559, -0.7610,  2.3890, -1.3782],
        [-1.1274, -1.7907, -1.0860,  2.3549, -3.6847]])
************************************************************************
net.2.bias
tensor([0.4826, 0.7057, 0.9702, 1.0532, 0.4214])
************************************************************************
net.4.weight
tensor([[7.3601, 4.7667, 6.2473, 5.0187, 7.2028]])
************************************************************************
net.4.bias
tensor([-0.2476])
************************************************************************


相关文章
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【单点知识】基于实例详解PyTorch中的DataLoader类
【单点知识】基于实例详解PyTorch中的DataLoader类
171 2
|
2月前
|
机器学习/深度学习 算法 PyTorch
【PyTorch实战演练】自调整学习率实例应用(附代码)
【PyTorch实战演练】自调整学习率实例应用(附代码)
53 0
|
24天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch与迁移学习:利用预训练模型提升性能
【4月更文挑战第18天】PyTorch支持迁移学习,助力提升深度学习性能。预训练模型(如ResNet、VGG)在大规模数据集(如ImageNet)训练后,可在新任务中加速训练,提高准确率。通过选择模型、加载预训练权重、修改结构和微调,可适应不同任务需求。迁移学习节省资源,但也需考虑源任务与目标任务的相似度及超参数选择。实践案例显示,预训练模型能有效提升小数据集上的图像分类任务性能。未来,迁移学习将继续在深度学习领域发挥重要作用。
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【单点知识】基于实例讲解PyTorch中的transforms类
【单点知识】基于实例讲解PyTorch中的transforms类
30 0
|
2月前
|
机器学习/深度学习 数据采集 PyTorch
【单点知识】基于实例讲解PyTorch中的ImageFolder类
【单点知识】基于实例讲解PyTorch中的ImageFolder类
28 0
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
基于PyTorch实战权重衰减——L2范数正则化方法(附代码)
基于PyTorch实战权重衰减——L2范数正则化方法(附代码)
51 0
|
3月前
|
机器学习/深度学习 编解码 PyTorch
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
|
2月前
|
机器学习/深度学习 算法 PyTorch
【PyTorch实战演练】深入剖析MTCNN(多任务级联卷积神经网络)并使用30行代码实现人脸识别
【PyTorch实战演练】深入剖析MTCNN(多任务级联卷积神经网络)并使用30行代码实现人脸识别
91 2
|
3月前
|
机器学习/深度学习 算法 PyTorch
pytorch实现手写数字识别 | MNIST数据集(全连接神经网络)
pytorch实现手写数字识别 | MNIST数据集(全连接神经网络)
|
5月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch深度学习中卷积神经网络(CNN)的讲解及图像处理实战(超详细 附源码)
PyTorch深度学习中卷积神经网络(CNN)的讲解及图像处理实战(超详细 附源码)
130 0