pytorch - swa_model模型保存的问题

简介: pytorch - swa_model模型保存的问题

AttributeError: Can‘t pickle local object ‘AveragedModel.__init__.<locals>.avg_fn‘

pytorch中使用swa_model 进行模型保存的时候会出现这样的问题。

torch.save(swa_model,"model.pt")

结果

AttributeError: Can‘t pickle local object ‘AveragedModel.__init__.<locals>.avg_fn

解决办法一:

用保存state_dict()代替保存整个模型

torch.save(swa_model.state_dict(),"model.pt")

这样就不会有报错了



后续加载模型:

swa_model.load_state_dict(torch.load("model.pt"))

解决办法二:

如果我们强行要保存整个模型。~~

根本原因是swa_utils.pyAveragedModel类将函数 avg_fn

class AveragedModel(Module):
    def __init__(self, model, device=None, avg_fn=None):
        super(AveragedModel, self).__init__()
        self.module = deepcopy(model)
        if device is not None:
            self.module = self.module.to(device)
        self.register_buffer('n_averaged',
                             torch.tensor(0, dtype=torch.long, device=device))
    #在这************************
        if avg_fn is None:
            def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
                return averaged_model_parameter + \
                    (model_parameter - averaged_model_parameter) / (num_averaged + 1)
        self.avg_fn = avg_fn

定义到了__init__下面了,只要把它放到__init__外面,class内,就不会报错了。



1.首先复制一份swa_utils.py到当前目录

2.其次将复制的swa_utils.py中的AveragedModel类改成下面的样子

class AveragedModel(Module):
    def __init__(self, model, device=None):
        super(AveragedModel, self).__init__()
        self.module = deepcopy(model)
        if device is not None:
            self.module = self.module.to(device)
        self.register_buffer('n_averaged',
                             torch.tensor(0, dtype=torch.long, device=device))
    def forward(self, *args, **kwargs):
    ...
    def update_parameters(self, model):
    ...
    def avg_fn(self,averaged_model_parameter, model_parameter, num_averaged):
        return averaged_model_parameter + \
               (model_parameter - averaged_model_parameter) / (num_averaged + 1)

3.接着从新的swa_utils.py导入AveragedModel,就不会报错了。

在这两种办法中,我们显然会选择第一种,但是第二种办法也可能会经常遇到。

解决方法二可以适用于

AttributeError: Can‘t pickle local object ‘XXXXXX.__init__.<locals>.XXX 

等的报错


相关链接:

SWA(随机权重平均) for Pytorch

相关文章
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】26.卷积神经网络之AlexNet模型介绍及其Pytorch实现【含完整代码】
【从零开始学习深度学习】26.卷积神经网络之AlexNet模型介绍及其Pytorch实现【含完整代码】
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】28.卷积神经网络之NiN模型介绍及其Pytorch实现【含完整代码】
【从零开始学习深度学习】28.卷积神经网络之NiN模型介绍及其Pytorch实现【含完整代码】
|
5天前
|
机器学习/深度学习 人工智能 PyTorch
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
18 1
|
1月前
|
机器学习/深度学习 算法 PyTorch
使用Pytorch中从头实现去噪扩散概率模型(DDPM)
在本文中,我们将构建基础的无条件扩散模型,即去噪扩散概率模型(DDPM)。从探究算法的直观工作原理开始,然后在PyTorch中从头构建它。本文主要关注算法背后的思想和具体实现细节。
8619 3
|
20天前
|
机器学习/深度学习 人工智能 PyTorch
人工智能平台PAI使用问题之如何布置一个PyTorch的模型
阿里云人工智能平台PAI是一个功能强大、易于使用的AI开发平台,旨在降低AI开发门槛,加速创新,助力企业和开发者高效构建、部署和管理人工智能应用。其中包含了一系列相互协同的产品与服务,共同构成一个完整的人工智能开发与应用生态系统。以下是对PAI产品使用合集的概述,涵盖数据处理、模型开发、训练加速、模型部署及管理等多个环节。
|
25天前
|
机器学习/深度学习 PyTorch 算法框架/工具
C++多态崩溃问题之在PyTorch中,如何定义一个简单的线性回归模型
C++多态崩溃问题之在PyTorch中,如何定义一个简单的线性回归模型
|
1月前
|
机器学习/深度学习 数据采集 PyTorch
使用 PyTorch 创建的多步时间序列预测的 Encoder-Decoder 模型
本文提供了一个用于解决 Kaggle 时间序列预测任务的 encoder-decoder 模型,并介绍了获得前 10% 结果所涉及的步骤。
37 0
|
2月前
|
机器学习/深度学习 算法 PyTorch
Pytorch实现线性回归模型
在机器学习和深度学习领域,线性回归是一种基本且广泛应用的算法,它简单易懂但功能强大,常作为更复杂模型的基础。使用PyTorch实现线性回归,不仅帮助初学者理解模型概念,还为探索高级模型奠定了基础。代码示例中,`creat_data()` 函数生成线性回归数据,包括噪声,`linear_regression()` 定义了线性模型,`square_loss()` 计算损失,而 `sgd()` 实现了梯度下降优化。
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch中的模型创建(一)
最全最详细的PyTorch神经网络创建
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具