在我做实验的过程中,由于卷积神经网络层数的更改,导致原始网络模型的权重加载失败,经过分析,是因为不匹配造成的,如下方式可以解决.
import torch import models checkpoint = torch.load("./logs/01origial/model_best.pth") model = models.__dict__["vgg"](dataset="Beans", depth=16) #提取网络结结构,分别是数据集,网络的深度和每层的输出通道数 model.load_state_dict(checkpoint['state_dict']) model_10 = models.__dict__["vgg10"](dataset="Beans", depth=10) model_dict = model.state_dict() model_10_dict = model_10.state_dict() pretrained_dict = {k: v for k, v in model_dict.items() if k in model_10_dict.keys()} model_10_dict.update(pretrained_dict) model_10.load_state_dict(model_10_dict)