SWA(随机权重平均) for Pytorch

简介: SWA(随机权重平均) for Pytorch

pytorch1.6中加入了随机权重平均(SWA)的api,使用起来更加方便了。

一.什么是Stochastic Weight Averaging(SWA)

SWA是使用修正后的学习率策略对SGD(或任何随机优化器)遍历的权重进行平均,从而可以得到更好的收敛效果。

随机梯度下降(SGD)在测试集上,趋向于收敛至损失相对低的地方,但却很难收敛至最低点,如上述左图中,经过几个epoch的训练,得到了W1,W2,W3三个权重,但无法收敛至最低点。如果使用SWA可以将三个权重加权平均,从而可能收敛至相对SGD更小的损失。

二.SWA与SGD的对比

从上面图中,可以发现,SGD在训练集收敛得比较好,但是在测试集效果并不如SWA。而SWA虽然在训练集收敛得不如SGD,但是在测试集上表现得更加好。下面得这张曲线图也可以看出两者的差异。

三.SWA大致的使用流程(pytorch)

上图是一种SWA的例子。先使用恒定学习率进行训练,接着线性衰减学习率,最后在恒定学习率上,累加它们的权重(SWA)。在使用SWA之前,可以配合任意的优化器使用,如SGD、Adam等,直到训练到一定周期,开始记录训练的权重,当训练完成后,再将记录的权重进行加权平均。注意:在训练的过程中是不进行预测的(下面的代码可以看到),直到最后训练完后,再加权,然后才开始预测。

from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR
loader, optimizer, model, loss_fn = ...  # 定义数据加载器,优化器,模型,损失
swa_model = AveragedModel(model)  
scheduler = CosineAnnealingLR(optimizer, T_max=100) # 使用学习率策略(余弦退火)
swa_start = 5  # 设置SWA开始的周期,当epoch到该值的时候才开始记录模型的权重
swa_scheduler = SWALR(optimizer, swa_lr=0.05) # 当SWA开始的时候,使用的学习率策略
for epoch in range(100):
      for input, target in loader:
          optimizer.zero_grad()
          loss_fn(model(input), target).backward()
          optimizer.step()
      if epoch > swa_start:
          swa_model.update_parameters(model)
          swa_scheduler.step()
      else:
          scheduler.step()
# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data 
preds = swa_model(test_input)

可以看到 使用了分为两个阶段的学习率策略,可以自由调整,SWALR中可以加入学习率策略的比如线性,余弦退火等。

torch.optim.swa_utils.update_bn(loader, swa_model)这一步的目的:

  • BN层没有在训练结束时计算激活统计信息。我们可以通过使用SWA模型对这些数据进行一次向前传递来计算这些统计数据。


四.Pytorch上使用swa的一些问题:

pytorch - swa_model模型保存的问题





参考链接:

https://blog.csdn.net/leviopku/article/details/84037946

https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/

相关文章
|
7月前
|
机器学习/深度学习 PyTorch 算法框架/工具
基于PyTorch实战权重衰减——L2范数正则化方法(附代码)
基于PyTorch实战权重衰减——L2范数正则化方法(附代码)
463 0
|
7月前
|
机器学习/深度学习 PyTorch 算法框架/工具
通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()
通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()
98 0
|
PyTorch 算法框架/工具
【PyTorch】初始化网络各层权重
【PyTorch】初始化网络各层权重
68 0
|
PyTorch 算法框架/工具
在pytorch中,模型权重的精度会影响模型在cpu上的推理速度吗?
在用pytorch训练模型时发现,模型训练的eopch越多,保存模型时模型权重的精度越好,模型在cpu上的推理的速度越慢,是因为模型权重精度会影响推理速度吗?如何调整pytorch模型参数的精度?
739 0
|
机器学习/深度学习 PyTorch 算法框架/工具
如何用Pytorch加载部分权重
在我做实验的过程中,由于卷积神经网络层数的更改,导致原始网络模型的权重加载失败,经过分析,是因为不匹配造成的,如下方式可以解决.
295 0
|
PyTorch 算法框架/工具
pytorch权重初始化
我们定义的网络如下所示
129 0
|
机器学习/深度学习 PyTorch 算法框架/工具
|
2月前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
397 2
|
25天前
|
机器学习/深度学习 人工智能 PyTorch
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
本文探讨了Transformer模型中变长输入序列的优化策略,旨在解决深度学习中常见的计算效率问题。文章首先介绍了批处理变长输入的技术挑战,特别是填充方法导致的资源浪费。随后,提出了多种优化技术,包括动态填充、PyTorch NestedTensors、FlashAttention2和XFormers的memory_efficient_attention。这些技术通过减少冗余计算、优化内存管理和改进计算模式,显著提升了模型的性能。实验结果显示,使用FlashAttention2和无填充策略的组合可以将步骤时间减少至323毫秒,相比未优化版本提升了约2.5倍。
42 3
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
|
2月前
|
机器学习/深度学习 自然语言处理 监控
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
79 8
利用 PyTorch Lightning 搭建一个文本分类模型