PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)

本文涉及的产品
日志服务 SLS,月写入数据量 50GB 1个月
简介: PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)

训练模型时,在众多训练好的模型中会有几个较好的模型,我们希望储存这些模型对应的参数值,避免后续难以训练出更好的结果,同时也方便我们复现这些模型,用于之后的研究。PyTorch提供了模型的保存与重载模块,包括torch.save()和torch.load(),以及pytorchtools中的EarlyStopping,这个模块就是用来解决上述的模型保存与重载问题

一、保存与重载模块

若希望保存/加载模型model的参数,而不保存/加载模型的结构,可以通过如下代码

其中state_dict是torch中的一个字典对象,将每一层与该层的对应参数张量建立映射关系

若希望同时保存/加载模型model的参数以及模型结构,而不保存/加载模型的结构,可以通过如下代码

为了获取性能良好的神经网络,训练网络的过程中需要进行许多对于模型各部分的设置,也就是超参数的调整。超参数之一就是训练周期(epoch),训练周期如果取值过小可能会导致欠拟合,取值过大可能会导致过拟合。为了避免训练周期设置不合适影响模型效果,EarlyStopping应运而生。EarlyStopping解决epoch需要手动设定的问题,也可以认为是一种避免网络发生过拟合的正则化方法

EarlyStopping的原理可以大致分为三个部分:

将原数据分为训练集和验证集;

只在训练集上进行训练,并每隔一个周期计算模型在验证集上的误差,如果随着周期的增加,在验证集上的测试误差也在增加,则停止训练;

将停止之后的权重作为网络的最终参数

初始化 early_stopping 对象:

EarlyStopping 对象的初始化包括三个参数,其含义如下:

patience(int) : 上次验证集损失值改善后等待几个epoch,默认值:7。

verbose(bool):如果值为True,为每个验证集损失值打印一条信息;若为False,则不打印,默认值:False。

delta(float):损失函数值改善的最小变化,当损失函数值的改善大于该值时,将会保存模型,默认值:0,即损失函数只要有改善即保存模型

定义一个函数,表示训练函数,希望通过 EarlyStopping 当测试集上的损失值有所下降时,将此时的信息打印出来,并且保存参数。 先创建将要用到的变量,以及初始化 earlystopping 对象

之后训练模型并保存损失值,计算每次迭代在训练集和测试集上的损失值得均值,并保存

调用 EarlyStopping 中的_call_()模块,判断损失值是否下降,若下降则进行保存,并打印信息

最后调用torch.load()加载最后一次的保存点,即最优模型,并返回模型,以及每轮迭代在训练集、测试集上的损失值的均值

二、可视化模块

在模型训练过程中,有时不仅需要保持和加载已经训练好的模型,也需要将训练过程中的训练集损失函数、验证集损失函数、模型计算图(即模型框架图、模型数据流图)等保持下来,供后续分析作图使用

例如,通过损失函数变化情况,可以观察模型是否收敛,通过模型计算图,可以观察数据流动情况等

Tensorboard可以将数据、模型计算图等进行可视化,会自动获取最新的数据信息,将其存入日志文件中,并且会在日志文件中更新信息,运行数据或模型最新的状态。Tensorboard中常用的模块包括如下七类

add_graph():添加网络结构图,将计算图可视化。

add_image()/add_images():添加单个图像数据/批量添加图像数据。

add_figure():添加matplotlib图片。

add_scalar()/add_scalars():添加一个标量/批量添加标量,在机器学习中可用于绘制损失函数。

add_histogram():添加统计分布直方图。

add_pr_curve():添加P-R(精准率-召回率)曲线。  

add_txt():添加文字

Tensorboard的整体用法,参见下图

TensorBoard中可以使用add_graph()函数保存模型计算图,该函数用于在tensorboard中创建存放网络结构的Graphs,函数及其参数如下:

model(torch.nn.Module) 表示需要可视化的网络模型;

input_to_model(torch.Tensor or list of torch.Tensor)表示模型的输入变量,如果模型输入为多个变量,则用list或元组按顺序传入多个变量即可;

verbose(bool)为开关语句,控制是否在控制台中打印输出网络的图形结构

例如,有一个数据类型为torch.nn.Module的变量model,输入的张量为input1和input2,期望返回模型计算图,则可以输入如下代码,即可在SummaryWriter的日志文件夹中保存数据流图

PyTorch中SummaryWriter的输出文件夹一般为runs文件,保存的日志文件不可以直接双击打开,需要在cmd命令窗口中将目录导航到runs文件夹的上一级目录,并输入tensorboard –logdir runs即可打开日志文件,打开后复制链接到浏览器中,即可打开保存的模型计算图或数据变量等

TensorBoard中可以使用add_scalar()/add_scalars()函数保存一个或在一张图中保存多个常量,如训练损失函数值、测试损失函数值、或将训练损失函数值和测试损失函数值保存在一张图中。

add_scalar()函数及参数如下:

 

tag(string)为数据标识符;

scalar_value(float or string)为标量值,即希望保存的数值;

global_step(int)为全局步长值,可理解为x轴坐标

add_scalars()函数及参数如下:

main_tag(string)为主标识符,即tag的父级名称;

tag_scalar_dict(dict)为保存tag及tag对应的值的字典类型数据;

global_step(int)为全局步长值,可理解为x轴坐标。

add_scalars()可以批量添加标量,例如,绘制y=xsinx、y=xcosx、y=tanx的图像,可以输入如下代码,保存的日志文件打开方式与上文所述相同

创作不易 觉得有帮助请点赞关注收藏~~~

相关实践学习
日志服务之使用Nginx模式采集日志
本文介绍如何通过日志服务控制台创建Nginx模式的Logtail配置快速采集Nginx日志并进行多维度分析。
相关文章
|
1月前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
153 2
|
1月前
|
机器学习/深度学习 自然语言处理 监控
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
55 8
利用 PyTorch Lightning 搭建一个文本分类模型
|
1月前
|
机器学习/深度学习 自然语言处理 数据建模
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
本文深入探讨了Transformer模型中的三种关键注意力机制:自注意力、交叉注意力和因果自注意力,这些机制是GPT-4、Llama等大型语言模型的核心。文章不仅讲解了理论概念,还通过Python和PyTorch从零开始实现这些机制,帮助读者深入理解其内部工作原理。自注意力机制通过整合上下文信息增强了输入嵌入,多头注意力则通过多个并行的注意力头捕捉不同类型的依赖关系。交叉注意力则允许模型在两个不同输入序列间传递信息,适用于机器翻译和图像描述等任务。因果自注意力确保模型在生成文本时仅考虑先前的上下文,适用于解码器风格的模型。通过本文的详细解析和代码实现,读者可以全面掌握这些机制的应用潜力。
51 3
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
|
2月前
|
机器学习/深度学习 PyTorch 调度
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
在深度学习中,学习率作为关键超参数对模型收敛速度和性能至关重要。传统方法采用统一学习率,但研究表明为不同层设置差异化学习率能显著提升性能。本文探讨了这一策略的理论基础及PyTorch实现方法,包括模型定义、参数分组、优化器配置及训练流程。通过示例展示了如何为ResNet18设置不同层的学习率,并介绍了渐进式解冻和层适应学习率等高级技巧,帮助研究者更好地优化模型训练。
131 4
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
|
2月前
|
机器学习/深度学习 监控 PyTorch
PyTorch 模型调试与故障排除指南
在深度学习领域,PyTorch 成为开发和训练神经网络的主要框架之一。本文为 PyTorch 开发者提供全面的调试指南,涵盖从基础概念到高级技术的内容。目标读者包括初学者、中级开发者和高级工程师。本文探讨常见问题及解决方案,帮助读者理解 PyTorch 的核心概念、掌握调试策略、识别性能瓶颈,并通过实际案例获得实践经验。无论是在构建简单神经网络还是复杂模型,本文都将提供宝贵的洞察和实用技巧,帮助开发者更高效地开发和优化 PyTorch 模型。
40 3
PyTorch 模型调试与故障排除指南
|
1月前
|
PyTorch 算法框架/工具
Pytorch学习笔记(一):torch.cat()模块的详解
这篇博客文章详细介绍了Pytorch中的torch.cat()函数,包括其定义、使用方法和实际代码示例,用于将两个或多个张量沿着指定维度进行拼接。
66 0
Pytorch学习笔记(一):torch.cat()模块的详解
|
1月前
|
存储 并行计算 PyTorch
探索PyTorch:模型的定义和保存方法
探索PyTorch:模型的定义和保存方法
|
1月前
|
机器学习/深度学习 算法 PyTorch
Pytorch的常用模块和用途说明
肆十二在B站分享PyTorch常用模块及其用途,涵盖核心库torch、神经网络库torch.nn、优化库torch.optim、数据加载工具torch.utils.data、计算机视觉库torchvision等,适合深度学习开发者参考学习。链接:[肆十二-哔哩哔哩](https://space.bilibili.com/161240964)
29 0
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
探索PyTorch:自动微分模块
探索PyTorch:自动微分模块
|
1月前
|
并行计算 开发工具 异构计算
在Windows平台使用源码编译和安装PyTorch3D指定版本
【10月更文挑战第6天】在 Windows 平台上,编译和安装指定版本的 PyTorch3D 需要先安装 Python、Visual Studio Build Tools 和 CUDA(如有需要),然后通过 Git 获取源码。建议创建虚拟环境以隔离依赖,并使用 `pip` 安装所需库。最后,在源码目录下运行 `python setup.py install` 进行编译和安装。完成后即可在 Python 中导入 PyTorch3D 使用。
154 0

热门文章

最新文章