Pytorch 保存和提取训练好的神经网络

简介: Pytorch 保存和提取训练好的神经网络

在pytorch中,保存神经网络用方法:


torch.save(net, 'net.pkl')


提取神经网络用方法:


torch.load('net.pkl')


保存神经网络有两种方式:


1、保存整个网络


torch.save(net, 'net.pkl')


这种方法能最大程度的保留网络的所有信息,缺点是读取网络时速度稍慢


2、保存网络的状态信息


torch.save(net.state_dict(), 'net_params.pkl')


这种方法只保留网络当前的状态信息,保存和读取速度快,保存的pkl文件体积小,缺点是在读取网络时需要自行先构建网络,否则无法还原信息


示例:


import torch
import torch.nn.functional as F
from torch.autograd import Variable
import matplotlib.pyplot as plt
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2 * torch.rand(x.size())
x, y = Variable(x).cuda(), Variable(y).cuda()
# 保存网络
def save():
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    ).cuda()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()
    for t in range(300):
        prediction = net(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    plt.figure(1, figsize=(10,3))
    plt.subplot(131)
    plt.scatter(x.data.cpu().numpy(), y.data.cpu().numpy())
    plt.plot(x.data.cpu().numpy(), prediction.data.cpu().numpy(), 'r-', lw=5)
    # 保存整个网络
    torch.save(net, 'net.pkl')
    # 保存网络当前的状态
    torch.save(net.state_dict(), 'net_params.pkl')
# 提取整个网络
def restore_net():
    net = torch.load('net.pkl').cuda()
    prediction = net(x)
    plt.figure(1, figsize=(10, 3))
    plt.subplot(132)
    plt.scatter(x.data.cpu().numpy(), y.data.cpu().numpy())
    plt.plot(x.data.cpu().numpy(), prediction.data.cpu().numpy(), 'r-', lw=5)
# 提取网络状态
def restore_params():
    net = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1),
    ).cuda()
    net.load_state_dict(torch.load('net_params.pkl'))
    prediction = net(x)
    plt.figure(1, figsize=(10, 3))
    plt.subplot(133)
    plt.scatter(x.data.cpu().numpy(), y.data.cpu().numpy())
    plt.plot(x.data.cpu().numpy(), prediction.data.cpu().numpy(), 'r-', lw=5)
save()
restore_net()
restore_params()
plt.show()


0a2653c851af460fa595bd959398a8f1.png

图一为保存的神经网络,图二、三分别为用不同方法提取的神经网络,可以看到,两种提取方式的结果是一致的


相关文章
|
3月前
|
机器学习/深度学习 PyTorch 算法框架/工具
目标检测实战(一):CIFAR10结合神经网络加载、训练、测试完整步骤
这篇文章介绍了如何使用PyTorch框架,结合CIFAR-10数据集,通过定义神经网络、损失函数和优化器,进行模型的训练和测试。
185 2
目标检测实战(一):CIFAR10结合神经网络加载、训练、测试完整步骤
|
3月前
|
机器学习/深度学习 数据可视化 计算机视觉
目标检测笔记(五):详细介绍并实现可视化深度学习中每层特征层的网络训练情况
这篇文章详细介绍了如何通过可视化深度学习中每层特征层来理解网络的内部运作,并使用ResNet系列网络作为例子,展示了如何在训练过程中加入代码来绘制和保存特征图。
70 1
目标检测笔记(五):详细介绍并实现可视化深度学习中每层特征层的网络训练情况
|
14天前
|
机器学习/深度学习 人工智能 PyTorch
使用PyTorch实现GPT-2直接偏好优化训练:DPO方法改进及其与监督微调的效果对比
本文将系统阐述DPO的工作原理、实现机制,以及其与传统RLHF和SFT方法的本质区别。
68 22
使用PyTorch实现GPT-2直接偏好优化训练:DPO方法改进及其与监督微调的效果对比
|
2天前
|
机器学习/深度学习 算法 PyTorch
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
软演员-评论家算法(Soft Actor-Critic, SAC)是深度强化学习领域的重要进展,基于最大熵框架优化策略,在探索与利用之间实现动态平衡。SAC通过双Q网络设计和自适应温度参数,提升了训练稳定性和样本效率。本文详细解析了SAC的数学原理、网络架构及PyTorch实现,涵盖演员网络的动作采样与对数概率计算、评论家网络的Q值估计及其损失函数,并介绍了完整的SAC智能体实现流程。SAC在连续动作空间中表现出色,具有高样本效率和稳定的训练过程,适合实际应用场景。
20 7
深度强化学习中SAC算法:数学原理、网络架构及其PyTorch实现
|
18天前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch Gemotric在昇腾上实现GraphSage图神经网络
本文详细介绍了如何在昇腾平台上使用PyTorch实现GraphSage算法,在CiteSeer数据集上进行图神经网络的分类训练。内容涵盖GraphSage的创新点、算法原理、网络架构及实战代码分析,通过采样和聚合方法高效处理大规模图数据。实验结果显示,模型在CiteSeer数据集上的分类准确率达到66.5%。
|
2月前
|
机器学习/深度学习 自然语言处理 语音技术
Python在深度学习领域的应用,重点讲解了神经网络的基础概念、基本结构、训练过程及优化技巧
本文介绍了Python在深度学习领域的应用,重点讲解了神经网络的基础概念、基本结构、训练过程及优化技巧,并通过TensorFlow和PyTorch等库展示了实现神经网络的具体示例,涵盖图像识别、语音识别等多个应用场景。
73 8
|
3月前
|
机器学习/深度学习 数据采集 算法
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)
这篇博客文章介绍了如何使用包含多个网络和多种训练策略的框架来完成多目标分类任务,涵盖了从数据准备到训练、测试和部署的完整流程,并提供了相关代码和配置文件。
73 0
目标分类笔记(一): 利用包含多个网络多种训练策略的框架来完成多目标分类任务(从数据准备到训练测试部署的完整流程)
|
3月前
|
机器学习/深度学习 算法 TensorFlow
深度学习笔记(五):学习率过大过小对于网络训练有何影响以及如何解决
学习率是深度学习中的关键超参数,它影响模型的训练进度和收敛性,过大或过小的学习率都会对网络训练产生负面影响,需要通过适当的设置和调整策略来优化。
606 0
深度学习笔记(五):学习率过大过小对于网络训练有何影响以及如何解决
|
3月前
|
机器学习/深度学习 算法
【机器学习】揭秘反向传播:深度学习中神经网络训练的奥秘
【机器学习】揭秘反向传播:深度学习中神经网络训练的奥秘
|
7月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】28.卷积神经网络之NiN模型介绍及其Pytorch实现【含完整代码】
【从零开始学习深度学习】28.卷积神经网络之NiN模型介绍及其Pytorch实现【含完整代码】

热门文章

最新文章