SWA(随机权重平均) for Pytorch

简介: SWA(随机权重平均) for Pytorch

pytorch1.6中加入了随机权重平均(SWA)的api,使用起来更加方便了。

一.什么是Stochastic Weight Averaging(SWA)

SWA是使用修正后的学习率策略对SGD(或任何随机优化器)遍历的权重进行平均,从而可以得到更好的收敛效果。

随机梯度下降(SGD)在测试集上,趋向于收敛至损失相对低的地方,但却很难收敛至最低点,如上述左图中,经过几个epoch的训练,得到了W1,W2,W3三个权重,但无法收敛至最低点。如果使用SWA可以将三个权重加权平均,从而可能收敛至相对SGD更小的损失。

二.SWA与SGD的对比

从上面图中,可以发现,SGD在训练集收敛得比较好,但是在测试集效果并不如SWA。而SWA虽然在训练集收敛得不如SGD,但是在测试集上表现得更加好。下面得这张曲线图也可以看出两者的差异。

三.SWA大致的使用流程(pytorch)

上图是一种SWA的例子。先使用恒定学习率进行训练,接着线性衰减学习率,最后在恒定学习率上,累加它们的权重(SWA)。在使用SWA之前,可以配合任意的优化器使用,如SGD、Adam等,直到训练到一定周期,开始记录训练的权重,当训练完成后,再将记录的权重进行加权平均。注意:在训练的过程中是不进行预测的(下面的代码可以看到),直到最后训练完后,再加权,然后才开始预测。

from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR
loader, optimizer, model, loss_fn = ...  # 定义数据加载器,优化器,模型,损失
swa_model = AveragedModel(model)  
scheduler = CosineAnnealingLR(optimizer, T_max=100) # 使用学习率策略(余弦退火)
swa_start = 5  # 设置SWA开始的周期,当epoch到该值的时候才开始记录模型的权重
swa_scheduler = SWALR(optimizer, swa_lr=0.05) # 当SWA开始的时候,使用的学习率策略
for epoch in range(100):
      for input, target in loader:
          optimizer.zero_grad()
          loss_fn(model(input), target).backward()
          optimizer.step()
      if epoch > swa_start:
          swa_model.update_parameters(model)
          swa_scheduler.step()
      else:
          scheduler.step()
# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data 
preds = swa_model(test_input)

可以看到 使用了分为两个阶段的学习率策略,可以自由调整,SWALR中可以加入学习率策略的比如线性,余弦退火等。

torch.optim.swa_utils.update_bn(loader, swa_model)这一步的目的:

  • BN层没有在训练结束时计算激活统计信息。我们可以通过使用SWA模型对这些数据进行一次向前传递来计算这些统计数据。


四.Pytorch上使用swa的一些问题:

pytorch - swa_model模型保存的问题





参考链接:

https://blog.csdn.net/leviopku/article/details/84037946

https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/

相关文章
|
机器学习/深度学习 PyTorch 算法框架/工具
基于PyTorch实战权重衰减——L2范数正则化方法(附代码)
基于PyTorch实战权重衰减——L2范数正则化方法(附代码)
648 0
|
机器学习/深度学习 PyTorch 算法框架/工具
通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()
通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()
253 0
|
PyTorch 算法框架/工具
【PyTorch】初始化网络各层权重
【PyTorch】初始化网络各层权重
137 0
|
PyTorch 算法框架/工具
在pytorch中,模型权重的精度会影响模型在cpu上的推理速度吗?
在用pytorch训练模型时发现,模型训练的eopch越多,保存模型时模型权重的精度越好,模型在cpu上的推理的速度越慢,是因为模型权重精度会影响推理速度吗?如何调整pytorch模型参数的精度?
863 0
|
机器学习/深度学习 PyTorch 算法框架/工具
如何用Pytorch加载部分权重
在我做实验的过程中,由于卷积神经网络层数的更改,导致原始网络模型的权重加载失败,经过分析,是因为不匹配造成的,如下方式可以解决.
376 0
|
PyTorch 算法框架/工具
pytorch权重初始化
我们定义的网络如下所示
172 0
|
机器学习/深度学习 PyTorch 算法框架/工具
|
21天前
|
机器学习/深度学习 数据采集 人工智能
PyTorch学习实战:AI从数学基础到模型优化全流程精解
本文系统讲解人工智能、机器学习与深度学习的层级关系,涵盖PyTorch环境配置、张量操作、数据预处理、神经网络基础及模型训练全流程,结合数学原理与代码实践,深入浅出地介绍激活函数、反向传播等核心概念,助力快速入门深度学习。
74 1
|
5月前
|
机器学习/深度学习 PyTorch API
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
本文深入探讨神经网络模型量化技术,重点讲解训练后量化(PTQ)与量化感知训练(QAT)两种主流方法。PTQ通过校准数据集确定量化参数,快速实现模型压缩,但精度损失较大;QAT在训练中引入伪量化操作,使模型适应低精度环境,显著提升量化后性能。文章结合PyTorch实现细节,介绍Eager模式、FX图模式及PyTorch 2导出量化等工具,并分享大语言模型Int4/Int8混合精度实践。最后总结量化最佳策略,包括逐通道量化、混合精度设置及目标硬件适配,助力高效部署深度学习模型。
684 21
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
|
21天前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
59 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节

热门文章

最新文章

推荐镜像

更多