训练模型时,在众多训练好的模型中会有几个较好的模型,我们希望储存这些模型对应的参数值,避免后续难以训练出更好的结果,同时也方便我们复现这些模型,用于之后的研究。PyTorch提供了模型的保存与重载模块,包括torch.save()和torch.load(),以及pytorchtools中的EarlyStopping,这个模块就是用来解决上述的模型保存与重载问题
一、保存与重载模块
若希望保存/加载模型model的参数,而不保存/加载模型的结构,可以通过如下代码
其中state_dict是torch中的一个字典对象,将每一层与该层的对应参数张量建立映射关系
若希望同时保存/加载模型model的参数以及模型结构,而不保存/加载模型的结构,可以通过如下代码
为了获取性能良好的神经网络,训练网络的过程中需要进行许多对于模型各部分的设置,也就是超参数的调整。超参数之一就是训练周期(epoch),训练周期如果取值过小可能会导致欠拟合,取值过大可能会导致过拟合。为了避免训练周期设置不合适影响模型效果,EarlyStopping应运而生。EarlyStopping解决epoch需要手动设定的问题,也可以认为是一种避免网络发生过拟合的正则化方法
EarlyStopping的原理可以大致分为三个部分:
将原数据分为训练集和验证集;
只在训练集上进行训练,并每隔一个周期计算模型在验证集上的误差,如果随着周期的增加,在验证集上的测试误差也在增加,则停止训练;
将停止之后的权重作为网络的最终参数
初始化 early_stopping 对象:
EarlyStopping 对象的初始化包括三个参数,其含义如下:
patience(int) : 上次验证集损失值改善后等待几个epoch,默认值:7。
verbose(bool):如果值为True,为每个验证集损失值打印一条信息;若为False,则不打印,默认值:False。
delta(float):损失函数值改善的最小变化,当损失函数值的改善大于该值时,将会保存模型,默认值:0,即损失函数只要有改善即保存模型
定义一个函数,表示训练函数,希望通过 EarlyStopping 当测试集上的损失值有所下降时,将此时的信息打印出来,并且保存参数。 先创建将要用到的变量,以及初始化 earlystopping 对象
之后训练模型并保存损失值,计算每次迭代在训练集和测试集上的损失值得均值,并保存
调用 EarlyStopping 中的_call_()模块,判断损失值是否下降,若下降则进行保存,并打印信息
最后调用torch.load()加载最后一次的保存点,即最优模型,并返回模型,以及每轮迭代在训练集、测试集上的损失值的均值
二、可视化模块
在模型训练过程中,有时不仅需要保持和加载已经训练好的模型,也需要将训练过程中的训练集损失函数、验证集损失函数、模型计算图(即模型框架图、模型数据流图)等保持下来,供后续分析作图使用
例如,通过损失函数变化情况,可以观察模型是否收敛,通过模型计算图,可以观察数据流动情况等
Tensorboard可以将数据、模型计算图等进行可视化,会自动获取最新的数据信息,将其存入日志文件中,并且会在日志文件中更新信息,运行数据或模型最新的状态。Tensorboard中常用的模块包括如下七类
add_graph():添加网络结构图,将计算图可视化。
add_image()/add_images():添加单个图像数据/批量添加图像数据。
add_figure():添加matplotlib图片。
add_scalar()/add_scalars():添加一个标量/批量添加标量,在机器学习中可用于绘制损失函数。
add_histogram():添加统计分布直方图。
add_pr_curve():添加P-R(精准率-召回率)曲线。
add_txt():添加文字
Tensorboard的整体用法,参见下图
TensorBoard中可以使用add_graph()函数保存模型计算图,该函数用于在tensorboard中创建存放网络结构的Graphs,函数及其参数如下:
model(torch.nn.Module) 表示需要可视化的网络模型;
input_to_model(torch.Tensor or list of torch.Tensor)表示模型的输入变量,如果模型输入为多个变量,则用list或元组按顺序传入多个变量即可;
verbose(bool)为开关语句,控制是否在控制台中打印输出网络的图形结构
例如,有一个数据类型为torch.nn.Module的变量model,输入的张量为input1和input2,期望返回模型计算图,则可以输入如下代码,即可在SummaryWriter的日志文件夹中保存数据流图
PyTorch中SummaryWriter的输出文件夹一般为runs文件,保存的日志文件不可以直接双击打开,需要在cmd命令窗口中将目录导航到runs文件夹的上一级目录,并输入tensorboard –logdir runs即可打开日志文件,打开后复制链接到浏览器中,即可打开保存的模型计算图或数据变量等
TensorBoard中可以使用add_scalar()/add_scalars()函数保存一个或在一张图中保存多个常量,如训练损失函数值、测试损失函数值、或将训练损失函数值和测试损失函数值保存在一张图中。
add_scalar()函数及参数如下:
tag(string)为数据标识符;
scalar_value(float or string)为标量值,即希望保存的数值;
global_step(int)为全局步长值,可理解为x轴坐标
add_scalars()函数及参数如下:
main_tag(string)为主标识符,即tag的父级名称;
tag_scalar_dict(dict)为保存tag及tag对应的值的字典类型数据;
global_step(int)为全局步长值,可理解为x轴坐标。
add_scalars()可以批量添加标量,例如,绘制y=xsinx、y=xcosx、y=tanx的图像,可以输入如下代码,保存的日志文件打开方式与上文所述相同
创作不易 觉得有帮助请点赞关注收藏~~~