pytorch 模型保存与加载

简介: pytorch 模型保存与加载

 一、模型保存有两种形式:保存整体模型(包括模型结构和参数)、只保存模型参数

import torch
 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
# 保存整体模型
output_dir = 'checkpoint.ckp'
model_to_save = model.module if hasattr(model, "module") else model
torch.save(model_to_save, output_dir)
# 加载整体模型
model = torch.load(output_dir, map_location=device)
 
# =========================================================================
# 只保存模型参数 state_dict
torch.save(model.state_dict(), output_dir)
# 只加载模型参数 state_dict
model.load_state_dict(torch.load(output_dir, map_location=device))

二、当训练的代码中使用了“torch.nn.DataParallel()”,这个命令是将网络在多块gpu中进行训练然后合并,这时采用上面“只保存模型参数”的方式时,保存的参数key中会在最前面多一个module.

解决方式有三个:

1.加载模型时去掉key中的module.

model.load_state_dict({k.replace('module.',''):v for k,v in torch.load('checkpoint.pth').items()})

2.加载时也用“torch.nn.DataParallel()

    if cuda:
        g_model = torch.nn.DataParallel(g_model)
        cudnn.benchmark = True
        g_model = g_model.cuda()
 
    if os.path.exists(model_path):
        print('Loading weights into state dict...')
        model_dict = g_model.state_dict()
        pretrained_dict = torch.load(model_path, map_location=device)
        g_model.load_state_dict(pretrained_dict)
        print('Finished!')

3.只有一个GPU就没有必要使用“torch.nn.DataParallel()”了

相关文章
|
6天前
|
机器学习/深度学习 人工智能 PyTorch
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
18 1
|
1月前
|
机器学习/深度学习 算法 PyTorch
使用Pytorch中从头实现去噪扩散概率模型(DDPM)
在本文中,我们将构建基础的无条件扩散模型,即去噪扩散概率模型(DDPM)。从探究算法的直观工作原理开始,然后在PyTorch中从头构建它。本文主要关注算法背后的思想和具体实现细节。
8619 3
|
21天前
|
机器学习/深度学习 人工智能 PyTorch
人工智能平台PAI使用问题之如何布置一个PyTorch的模型
阿里云人工智能平台PAI是一个功能强大、易于使用的AI开发平台,旨在降低AI开发门槛,加速创新,助力企业和开发者高效构建、部署和管理人工智能应用。其中包含了一系列相互协同的产品与服务,共同构成一个完整的人工智能开发与应用生态系统。以下是对PAI产品使用合集的概述,涵盖数据处理、模型开发、训练加速、模型部署及管理等多个环节。
|
25天前
|
机器学习/深度学习 PyTorch 算法框架/工具
C++多态崩溃问题之在PyTorch中,如何定义一个简单的线性回归模型
C++多态崩溃问题之在PyTorch中,如何定义一个简单的线性回归模型
|
1月前
|
机器学习/深度学习 数据采集 PyTorch
使用 PyTorch 创建的多步时间序列预测的 Encoder-Decoder 模型
本文提供了一个用于解决 Kaggle 时间序列预测任务的 encoder-decoder 模型,并介绍了获得前 10% 结果所涉及的步骤。
37 0
|
2月前
|
机器学习/深度学习 算法 PyTorch
Pytorch实现线性回归模型
在机器学习和深度学习领域,线性回归是一种基本且广泛应用的算法,它简单易懂但功能强大,常作为更复杂模型的基础。使用PyTorch实现线性回归,不仅帮助初学者理解模型概念,还为探索高级模型奠定了基础。代码示例中,`creat_data()` 函数生成线性回归数据,包括噪声,`linear_regression()` 定义了线性模型,`square_loss()` 计算损失,而 `sgd()` 实现了梯度下降优化。
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch中的模型创建(一)
最全最详细的PyTorch神经网络创建
|
2月前
|
机器学习/深度学习 PyTorch 算法框架/工具
|
1月前
|
机器学习/深度学习 PyTorch TensorFlow
在深度学习中,数据增强是一种常用的技术,用于通过增加训练数据的多样性来提高模型的泛化能力。`albumentations`是一个强大的Python库,用于图像增强,支持多种图像变换操作,并且可以与深度学习框架(如PyTorch、TensorFlow等)无缝集成。
在深度学习中,数据增强是一种常用的技术,用于通过增加训练数据的多样性来提高模型的泛化能力。`albumentations`是一个强大的Python库,用于图像增强,支持多种图像变换操作,并且可以与深度学习框架(如PyTorch、TensorFlow等)无缝集成。
|
2月前
|
机器学习/深度学习 自然语言处理 PyTorch
【从零开始学习深度学习】48.Pytorch_NLP实战案例:如何使用预训练的词向量模型求近义词和类比词
【从零开始学习深度学习】48.Pytorch_NLP实战案例:如何使用预训练的词向量模型求近义词和类比词