请问优化PyTorch模型时如何从torchvision加载ResNet50模型?
您好,由于PAI-Blade仅支持ScriptModule,因此需要转换模型格式,代码如下:
model = models.resnet50().float().cuda() # 准备模型。
model = torch.jit.script(model).eval() # 转换成ScriptModule。
dummy = torch.rand(1, 3, 224, 224).cuda() # 构造测试数据。
首先,需要在阿里云上安装PyTorch和torchvision,然后使用以下代码从torchvision加载ResNet50模型:
from torchvision.models import resnet50
model = resnet50(pretrained=True)
由于与resnet50的分类数不一样,所以在调用时,要使用num_classes=分类数。 如果需要加载模型本身的参数,需要使用pretrained=True。由于最后一层的分类数不一样,所以最后一层的参数数目也就不一样,所以在加载模型参数时要去掉最后一层
def _resnet(
arch: str,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
pretrained: bool,
progress: bool,
**kwargs: Any
) -> ResNet:
model = ResNet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
for k in list(state_dict.keys()): #固定遍历对象
print(k)
if k == "fc.weight" or k == "fc.bias":
state_dict.pop(k) #删除最后一层的模型参数
model.load_state_dict(state_dict,strict=False) #非严格加载模型参数
return model
版权声明:本文内容由阿里云实名注册用户自发贡献,版权归原作者所有,阿里云开发者社区不拥有其著作权,亦不承担相应法律责任。具体规则请查看《阿里云开发者社区用户服务协议》和《阿里云开发者社区知识产权保护指引》。如果您发现本社区中有涉嫌抄袭的内容,填写侵权投诉表单进行举报,一经查实,本社区将立刻删除涉嫌侵权内容。
人工智能平台 PAI(Platform for AI,原机器学习平台PAI)是面向开发者和企业的机器学习/深度学习工程平台,提供包含数据标注、模型构建、模型训练、模型部署、推理优化在内的AI开发全链路服务,内置140+种优化算法,具备丰富的行业场景插件,为用户提供低门槛、高性能的云原生AI工程化能力。