AttributeError: Can‘t pickle local object ‘AveragedModel.__init__.<locals>.avg_fn‘
在pytorch中使用swa_model 进行模型保存的时候会出现这样的问题。
torch.save(swa_model,"model.pt")
结果
AttributeError: Can‘t pickle local object ‘AveragedModel.__init__.<locals>.avg_fn
解决办法一:
用保存state_dict()代替保存整个模型
torch.save(swa_model.state_dict(),"model.pt")
这样就不会有报错了
后续加载模型:
swa_model.load_state_dict(torch.load("model.pt"))
解决办法二:
如果我们强行要保存整个模型。~~
根本原因是swa_utils.py
中AveragedModel
类将函数 avg_fn
class AveragedModel(Module): def __init__(self, model, device=None, avg_fn=None): super(AveragedModel, self).__init__() self.module = deepcopy(model) if device is not None: self.module = self.module.to(device) self.register_buffer('n_averaged', torch.tensor(0, dtype=torch.long, device=device)) #在这************************ if avg_fn is None: def avg_fn(averaged_model_parameter, model_parameter, num_averaged): return averaged_model_parameter + \ (model_parameter - averaged_model_parameter) / (num_averaged + 1) self.avg_fn = avg_fn
定义到了__init__
下面了,只要把它放到__init__
外面,class
内,就不会报错了。
1.首先复制一份swa_utils.py
到当前目录
2.其次将复制的swa_utils.py
中的AveragedModel
类改成下面的样子
class AveragedModel(Module): def __init__(self, model, device=None): super(AveragedModel, self).__init__() self.module = deepcopy(model) if device is not None: self.module = self.module.to(device) self.register_buffer('n_averaged', torch.tensor(0, dtype=torch.long, device=device)) def forward(self, *args, **kwargs): ... def update_parameters(self, model): ... def avg_fn(self,averaged_model_parameter, model_parameter, num_averaged): return averaged_model_parameter + \ (model_parameter - averaged_model_parameter) / (num_averaged + 1)
3.接着从新的swa_utils.py
导入AveragedModel
,就不会报错了。
在这两种办法中,我们显然会选择第一种,但是第二种办法也可能会经常遇到。
解决方法二可以适用于
AttributeError: Can‘t pickle local object ‘XXXXXX.__init__.<locals>.XXX
等的报错
相关链接: