SWA(随机权重平均) for Pytorch

简介: SWA(随机权重平均) for Pytorch

pytorch1.6中加入了随机权重平均(SWA)的api,使用起来更加方便了。

一.什么是Stochastic Weight Averaging(SWA)

SWA是使用修正后的学习率策略对SGD(或任何随机优化器)遍历的权重进行平均,从而可以得到更好的收敛效果。

随机梯度下降(SGD)在测试集上,趋向于收敛至损失相对低的地方,但却很难收敛至最低点,如上述左图中,经过几个epoch的训练,得到了W1,W2,W3三个权重,但无法收敛至最低点。如果使用SWA可以将三个权重加权平均,从而可能收敛至相对SGD更小的损失。

二.SWA与SGD的对比

从上面图中,可以发现,SGD在训练集收敛得比较好,但是在测试集效果并不如SWA。而SWA虽然在训练集收敛得不如SGD,但是在测试集上表现得更加好。下面得这张曲线图也可以看出两者的差异。

三.SWA大致的使用流程(pytorch)

上图是一种SWA的例子。先使用恒定学习率进行训练,接着线性衰减学习率,最后在恒定学习率上,累加它们的权重(SWA)。在使用SWA之前,可以配合任意的优化器使用,如SGD、Adam等,直到训练到一定周期,开始记录训练的权重,当训练完成后,再将记录的权重进行加权平均。注意:在训练的过程中是不进行预测的(下面的代码可以看到),直到最后训练完后,再加权,然后才开始预测。

from torch.optim.swa_utils import AveragedModel, SWALR
from torch.optim.lr_scheduler import CosineAnnealingLR
loader, optimizer, model, loss_fn = ...  # 定义数据加载器,优化器,模型,损失
swa_model = AveragedModel(model)  
scheduler = CosineAnnealingLR(optimizer, T_max=100) # 使用学习率策略(余弦退火)
swa_start = 5  # 设置SWA开始的周期,当epoch到该值的时候才开始记录模型的权重
swa_scheduler = SWALR(optimizer, swa_lr=0.05) # 当SWA开始的时候,使用的学习率策略
for epoch in range(100):
      for input, target in loader:
          optimizer.zero_grad()
          loss_fn(model(input), target).backward()
          optimizer.step()
      if epoch > swa_start:
          swa_model.update_parameters(model)
          swa_scheduler.step()
      else:
          scheduler.step()
# Update bn statistics for the swa_model at the end
torch.optim.swa_utils.update_bn(loader, swa_model)
# Use swa_model to make predictions on test data 
preds = swa_model(test_input)

可以看到 使用了分为两个阶段的学习率策略,可以自由调整,SWALR中可以加入学习率策略的比如线性,余弦退火等。

torch.optim.swa_utils.update_bn(loader, swa_model)这一步的目的:

  • BN层没有在训练结束时计算激活统计信息。我们可以通过使用SWA模型对这些数据进行一次向前传递来计算这些统计数据。


四.Pytorch上使用swa的一些问题:

pytorch - swa_model模型保存的问题





参考链接:

https://blog.csdn.net/leviopku/article/details/84037946

https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/

目录
打赏
0
0
0
0
7
分享
相关文章
基于PyTorch实战权重衰减——L2范数正则化方法(附代码)
基于PyTorch实战权重衰减——L2范数正则化方法(附代码)
522 0
通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()
通过实例学习Pytorch加载权重.load_state_dict()与保存权重.save()
125 0
在pytorch中,模型权重的精度会影响模型在cpu上的推理速度吗?
在用pytorch训练模型时发现,模型训练的eopch越多,保存模型时模型权重的精度越好,模型在cpu上的推理的速度越慢,是因为模型权重精度会影响推理速度吗?如何调整pytorch模型参数的精度?
772 0
如何用Pytorch加载部分权重
在我做实验的过程中,由于卷积神经网络层数的更改,导致原始网络模型的权重加载失败,经过分析,是因为不匹配造成的,如下方式可以解决.
310 0
基于昇腾用PyTorch实现传统CTR模型WideDeep网络
本文介绍了如何在昇腾平台上使用PyTorch实现经典的WideDeep网络模型,以处理推荐系统中的点击率(CTR)预测问题。
207 66
用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解
本文详细介绍了DeepSeek R1模型的构建过程,涵盖从基础模型选型到多阶段训练流程,再到关键技术如强化学习、拒绝采样和知识蒸馏的应用。
75 3
用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
684 2

热门文章

最新文章

AI助理

你好,我是AI助理

可以解答问题、推荐解决方案等