PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)

简介: PyTorch基础之模型保存与重载模块、可视化模块讲解(附源码)

训练模型时,在众多训练好的模型中会有几个较好的模型,我们希望储存这些模型对应的参数值,避免后续难以训练出更好的结果,同时也方便我们复现这些模型,用于之后的研究。PyTorch提供了模型的保存与重载模块,包括torch.save()和torch.load(),以及pytorchtools中的EarlyStopping,这个模块就是用来解决上述的模型保存与重载问题

一、保存与重载模块

若希望保存/加载模型model的参数,而不保存/加载模型的结构,可以通过如下代码

其中state_dict是torch中的一个字典对象,将每一层与该层的对应参数张量建立映射关系

若希望同时保存/加载模型model的参数以及模型结构,而不保存/加载模型的结构,可以通过如下代码

为了获取性能良好的神经网络,训练网络的过程中需要进行许多对于模型各部分的设置,也就是超参数的调整。超参数之一就是训练周期(epoch),训练周期如果取值过小可能会导致欠拟合,取值过大可能会导致过拟合。为了避免训练周期设置不合适影响模型效果,EarlyStopping应运而生。EarlyStopping解决epoch需要手动设定的问题,也可以认为是一种避免网络发生过拟合的正则化方法

EarlyStopping的原理可以大致分为三个部分:

将原数据分为训练集和验证集;

只在训练集上进行训练,并每隔一个周期计算模型在验证集上的误差,如果随着周期的增加,在验证集上的测试误差也在增加,则停止训练;

将停止之后的权重作为网络的最终参数

初始化 early_stopping 对象:

EarlyStopping 对象的初始化包括三个参数,其含义如下:

patience(int) : 上次验证集损失值改善后等待几个epoch,默认值:7。

verbose(bool):如果值为True,为每个验证集损失值打印一条信息;若为False,则不打印,默认值:False。

delta(float):损失函数值改善的最小变化,当损失函数值的改善大于该值时,将会保存模型,默认值:0,即损失函数只要有改善即保存模型

定义一个函数,表示训练函数,希望通过 EarlyStopping 当测试集上的损失值有所下降时,将此时的信息打印出来,并且保存参数。 先创建将要用到的变量,以及初始化 earlystopping 对象

之后训练模型并保存损失值,计算每次迭代在训练集和测试集上的损失值得均值,并保存

调用 EarlyStopping 中的_call_()模块,判断损失值是否下降,若下降则进行保存,并打印信息

最后调用torch.load()加载最后一次的保存点,即最优模型,并返回模型,以及每轮迭代在训练集、测试集上的损失值的均值

二、可视化模块

在模型训练过程中,有时不仅需要保持和加载已经训练好的模型,也需要将训练过程中的训练集损失函数、验证集损失函数、模型计算图(即模型框架图、模型数据流图)等保持下来,供后续分析作图使用

例如,通过损失函数变化情况,可以观察模型是否收敛,通过模型计算图,可以观察数据流动情况等

Tensorboard可以将数据、模型计算图等进行可视化,会自动获取最新的数据信息,将其存入日志文件中,并且会在日志文件中更新信息,运行数据或模型最新的状态。Tensorboard中常用的模块包括如下七类

add_graph():添加网络结构图,将计算图可视化。

add_image()/add_images():添加单个图像数据/批量添加图像数据。

add_figure():添加matplotlib图片。

add_scalar()/add_scalars():添加一个标量/批量添加标量,在机器学习中可用于绘制损失函数。

add_histogram():添加统计分布直方图。

add_pr_curve():添加P-R(精准率-召回率)曲线。  

add_txt():添加文字

Tensorboard的整体用法,参见下图

TensorBoard中可以使用add_graph()函数保存模型计算图,该函数用于在tensorboard中创建存放网络结构的Graphs,函数及其参数如下:

model(torch.nn.Module) 表示需要可视化的网络模型;

input_to_model(torch.Tensor or list of torch.Tensor)表示模型的输入变量,如果模型输入为多个变量,则用list或元组按顺序传入多个变量即可;

verbose(bool)为开关语句,控制是否在控制台中打印输出网络的图形结构

例如,有一个数据类型为torch.nn.Module的变量model,输入的张量为input1和input2,期望返回模型计算图,则可以输入如下代码,即可在SummaryWriter的日志文件夹中保存数据流图

PyTorch中SummaryWriter的输出文件夹一般为runs文件,保存的日志文件不可以直接双击打开,需要在cmd命令窗口中将目录导航到runs文件夹的上一级目录,并输入tensorboard –logdir runs即可打开日志文件,打开后复制链接到浏览器中,即可打开保存的模型计算图或数据变量等

TensorBoard中可以使用add_scalar()/add_scalars()函数保存一个或在一张图中保存多个常量,如训练损失函数值、测试损失函数值、或将训练损失函数值和测试损失函数值保存在一张图中。

add_scalar()函数及参数如下:

 

tag(string)为数据标识符;

scalar_value(float or string)为标量值,即希望保存的数值;

global_step(int)为全局步长值,可理解为x轴坐标

add_scalars()函数及参数如下:

main_tag(string)为主标识符,即tag的父级名称;

tag_scalar_dict(dict)为保存tag及tag对应的值的字典类型数据;

global_step(int)为全局步长值,可理解为x轴坐标。

add_scalars()可以批量添加标量,例如,绘制y=xsinx、y=xcosx、y=tanx的图像,可以输入如下代码,保存的日志文件打开方式与上文所述相同

创作不易 觉得有帮助请点赞关注收藏~~~

相关实践学习
日志服务之使用Nginx模式采集日志
本文介绍如何通过日志服务控制台创建Nginx模式的Logtail配置快速采集Nginx日志并进行多维度分析。
相关文章
|
2月前
|
机器学习/深度学习 自然语言处理 PyTorch
【PyTorch实战演练】基于AlexNet的预训练模型介绍
【PyTorch实战演练】基于AlexNet的预训练模型介绍
85 0
|
3月前
|
机器学习/深度学习 并行计算 PyTorch
TensorRT部署系列 | 如何将模型从 PyTorch 转换为 TensorRT 并加速推理?
TensorRT部署系列 | 如何将模型从 PyTorch 转换为 TensorRT 并加速推理?
161 0
|
1月前
|
机器学习/深度学习 关系型数据库 MySQL
大模型中常用的注意力机制GQA详解以及Pytorch代码实现
GQA是一种结合MQA和MHA优点的注意力机制,旨在保持MQA的速度并提供MHA的精度。它将查询头分成组,每组共享键和值。通过Pytorch和einops库,可以简洁实现这一概念。GQA在保持高效性的同时接近MHA的性能,是高负载系统优化的有力工具。相关论文和非官方Pytorch实现可进一步探究。
105 4
|
4月前
|
机器学习/深度学习 PyTorch 算法框架/工具
Pytorch使用VGG16模型进行预测猫狗二分类
深度学习已经在计算机视觉领域取得了巨大的成功,特别是在图像分类任务中。VGG16是深度学习中经典的卷积神经网络(Convolutional Neural Network,CNN)之一,由牛津大学的Karen Simonyan和Andrew Zisserman在2014年提出。VGG16网络以其深度和简洁性而闻名,是图像分类中的重要里程碑。
|
2月前
|
机器学习/深度学习 数据采集 PyTorch
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
使用PyTorch解决多分类问题:构建、训练和评估深度学习模型
|
14天前
|
PyTorch 算法框架/工具 Python
【pytorch框架】对模型知识的基本了解
【pytorch框架】对模型知识的基本了解
|
24天前
|
机器学习/深度学习 算法 PyTorch
PyTorch模型优化与调优:正则化、批归一化等技巧
【4月更文挑战第18天】本文探讨了PyTorch中提升模型性能的优化技巧,包括正则化(L1/L2正则化、Dropout)、批归一化、学习率调整策略和模型架构优化。正则化防止过拟合,Dropout提高泛化能力;批归一化加速训练并提升性能;学习率调整策略动态优化训练效果;模型架构优化涉及网络结构和参数的调整。这些方法有助于实现更高效的深度学习模型。
|
24天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch与迁移学习:利用预训练模型提升性能
【4月更文挑战第18天】PyTorch支持迁移学习,助力提升深度学习性能。预训练模型(如ResNet、VGG)在大规模数据集(如ImageNet)训练后,可在新任务中加速训练,提高准确率。通过选择模型、加载预训练权重、修改结构和微调,可适应不同任务需求。迁移学习节省资源,但也需考虑源任务与目标任务的相似度及超参数选择。实践案例显示,预训练模型能有效提升小数据集上的图像分类任务性能。未来,迁移学习将继续在深度学习领域发挥重要作用。
|
26天前
|
机器学习/深度学习 数据可视化 PyTorch
PyTorch小技巧:使用Hook可视化网络层激活(各层输出)
这篇文章将演示如何可视化PyTorch激活层。可视化激活,即模型内各层的输出,对于理解深度神经网络如何处理视觉信息至关重要,这有助于诊断模型行为并激发改进。
19 1
|
2月前
|
PyTorch 算法框架/工具 Python
Pytorch构建网络模型时super(__class__, self).__init__()的作用
Pytorch构建网络模型时super(__class__, self).__init__()的作用
12 0