开发者社区 > 大数据与机器学习 > 人工智能平台PAI > 正文

请问优化PyTorch模型时如何从torchvision加载ResNet50模型?

已解决

请问优化PyTorch模型时如何从torchvision加载ResNet50模型?

展开
收起
felix@ 2023-01-29 12:40:27 734 0
3 条回答
写回答
取消 提交回答
  • 今天也要加油吖~
    采纳回答

    您好,由于PAI-Blade仅支持ScriptModule,因此需要转换模型格式,代码如下:

    model = models.resnet50().float().cuda()  # 准备模型。
    model = torch.jit.script(model).eval()    # 转换成ScriptModule。
    dummy = torch.rand(1, 3, 224, 224).cuda() # 构造测试数据。
    
    2023-01-29 12:46:09
    赞同 展开评论 打赏
  • 首先,需要在阿里云上安装PyTorch和torchvision,然后使用以下代码从torchvision加载ResNet50模型:

    from torchvision.models import resnet50
    model = resnet50(pretrained=True)
    
    2023-01-29 16:35:37
    赞同 展开评论 打赏
  • 由于与resnet50的分类数不一样,所以在调用时,要使用num_classes=分类数。 如果需要加载模型本身的参数,需要使用pretrained=True。由于最后一层的分类数不一样,所以最后一层的参数数目也就不一样,所以在加载模型参数时要去掉最后一层

    def _resnet(
        arch: str,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        pretrained: bool,
        progress: bool,
        **kwargs: Any
    ) -> ResNet:
        model = ResNet(block, layers, **kwargs)
        if pretrained:
            state_dict = load_state_dict_from_url(model_urls[arch],
                                                  progress=progress)
            
            for k in list(state_dict.keys()):  #固定遍历对象
                print(k)
                if k == "fc.weight" or k == "fc.bias":
                    state_dict.pop(k)  #删除最后一层的模型参数
             
            
            model.load_state_dict(state_dict,strict=False)  #非严格加载模型参数
        return model
    
    2023-01-29 12:46:10
    赞同 展开评论 打赏

热门讨论

热门文章

相关电子书

更多
低代码开发师(初级)实战教程 立即下载
冬季实战营第三期:MySQL数据库进阶实战 立即下载
阿里巴巴DevOps 最佳实践手册 立即下载