Error when loading a model: RuntimeError: Error(s) in loading state_dict for Net: Unexpected key(s) in state_dict: XXX
Traceback (most recent call last):
  File "demo.py", line 380, in <module>
    model.load_state_dict(torch.load('./0428.pth'))
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1482, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for ViT:
    Unexpected key(s) in state_dict: "transformer.skipcat.3.weight", "transformer.skipcat.3.bias", "transformer.skipcat.4.weight", "transformer.skipcat.4.bias".
Cause:
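The model used at load time does not match the one used during training. The checkpoint contains parameters (here transformer.skipcat.3 and transformer.skipcat.4) that the current ViT instance does not define, so they show up as unexpected keys.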
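To see where the extra keys come from, here is a minimal sketch that reproduces the error. The `Net` class and its `depth` argument are hypothetical stand-ins for the real ViT, assuming the mismatch comes from a layer count that changed between training and loading:

```python
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical model: `depth` controls how many skip layers are built."""

    def __init__(self, depth: int):
        super().__init__()
        self.skipcat = nn.ModuleList([nn.Linear(4, 4) for _ in range(depth)])


# Model as it was when training: five skip layers (skipcat.0 .. skipcat.4).
state_dict = Net(depth=5).state_dict()

# Model as defined at load time: only three skip layers.
model = Net(depth=3)

message = ""
try:
    model.load_state_dict(state_dict)  # strict=True is the default
except RuntimeError as err:
    message = str(err)

print(message)
```

The printed message lists skipcat.3.* and skipcat.4.* as unexpected keys, the same pattern as in the traceback above.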
Solution:
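Change model.load_state_dict(state_dict) to model.load_state_dict(state_dict, False), i.e. pass strict=False so that keys that do not match between the checkpoint and the model are ignored instead of raising an error: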
model.load_state_dict(torch.load('models/params.pt'),strict=False)
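Because strict=False skips mismatched keys silently, it can be worth checking what was skipped. load_state_dict returns a named tuple with missing_keys and unexpected_keys fields. A minimal sketch, again using a hypothetical `Net` whose `depth` differs between training and loading:

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Hypothetical model: `depth` controls how many skip layers are built."""

    def __init__(self, depth: int):
        super().__init__()
        self.skipcat = nn.ModuleList([nn.Linear(4, 4) for _ in range(depth)])


trained = Net(depth=5)
model = Net(depth=3)

# strict=False skips unmatched keys but still reports them in the return value.
result = model.load_state_dict(trained.state_dict(), strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)

# The layers that exist in both models were still loaded.
print(torch.equal(model.skipcat[0].weight, trained.skipcat[0].weight))
```

Note that strict=False only relaxes missing and unexpected keys; a parameter whose shape differs between checkpoint and model still raises a RuntimeError.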
Problem solved~