聊一聊学习率预热linear warmup

简介: 聊一聊学习率预热linear warmup

什么是warmup


warmup是针对学习率learning rate优化的一种策略,主要过程是,在预热期间,学习率从0线性(也可非线性)增加到优化器中的初始预设lr,之后使其学习率从优化器中的初始lr线性降低到0。如下图所示:


73.png

image


warmup的作用


由于刚开始训练时,模型的权重(weights)是随机初始化的,此时若选择一个较大的学习率,可能带来模型的不稳定(振荡),选择Warmup预热学习率的方式,可以使得开始训练的几个epoch或者一些step内学习率较小,在预热的小学习率下,模型可以慢慢趋于稳定,等模型相对稳定后再选择预先设置的学习率进行训练,使得模型收敛速度变得更快,模型效果更佳。


为什么warmup有效


这个问题目前还没有被充分证明,下面是来自知乎的回答解释:


https://www.zhihu.com/question/338066667

从理论层面上可以解释为:


  • 有助于减缓模型在初始阶段对mini-batch的提前过拟合现象,保持分布的平稳
  • 有助于保持模型深层的稳定性

从训练效果可以体现为:

  • 一开始神经网络输出比较random,loss比较大,容易不收敛,因此用小点的学习率, 学一丢丢,慢慢涨上去。
  • 梯度偏离真正较优的方向可能性比较大,那就走短一点错了还可以掰回来。


如何使用warmup


  • 实例1:warm_up_ratio 设置预热步数

from transformers import AdanW, get_linear_schedule_with_warmup
optimizer = AdamW(model.parameters(), lr=lr, eps=adam_epsilon)
len_dataset = 3821 # 可以根据pytorch中的len(Dataset)计算
epoch = 30
batch_size = 32
total_steps = (len_dataset // batch_size) * epoch if len_dataset % batch_size = 0 else (len_dataset // batch_size + 1) * epoch # 每一个epoch中有多少个step可以根据len(DataLoader)计算:total_steps = len(DataLoader) * epoch
warm_up_ratio = 0.1 # 定义要预热的step
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = warm_up_ratio * total_steps, num_training_steps = total_steps)
......
optimizer.step()
scheduler.step()
optimizer.zero_grad()


  • 实例1:num_warmup_steps 设置预热步数

# training steps 的数量: [number of batches] x [number of epochs].
total_steps = len(train_dataloader) * epochs
# 设计 learning rate scheduler
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = 50, 
                                            num_training_steps = total_steps)


经验参数选择


一般可取训练steps的10%,参考BERT。这里可以根据具体任务进行调整,主要需要通过warmup来使得学习率可以适应不同的训练集合,另外我们也可以通过训练误差观察loss抖动的关键位置,找出合适的学习率


其他非线性warmp策略

def _get_scheduler(optimizer, scheduler: str, warmup_steps: int, t_total: int):
        """
        Returns the correct learning rate scheduler. Available scheduler: constantlr, warmupconstant, warmuplinear, warmupcosine, warmupcosinewithhardrestarts
        """
        scheduler = scheduler.lower()
        if scheduler == 'constantlr':
            return transformers.get_constant_schedule(optimizer)
        elif scheduler == 'warmupconstant':
            return transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps)
        elif scheduler == 'warmuplinear':
            return transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
        elif scheduler == 'warmupcosine':
            return transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
        elif scheduler == 'warmupcosinewithhardrestarts':
            return transformers.get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)
        else:
            raise ValueError("Unknown scheduler {}".format(scheduler))


参考资料

相关文章
|
存储 缓存 算法
哈希函数:保护数据完整性的关键
哈希函数:保护数据完整性的关键
|
8月前
|
人工智能 算法 API
智能体IP操盘手:AI产业的下一个核心职业——从技术开发到智能体人格化运营的新趋势
随着大模型与云计算发展,AI智能体正从工具演变为具备人格的数字IP。本文探讨“智能体IP操盘手”这一新兴职业的崛起,涵盖技术开发、人格设计与商业运营,并分析阿里云如何赋能智能体产业化,推动教育与产业融合,开启数字经济新曲线。(238字)
|
7月前
|
人工智能 自然语言处理 算法
开发者视角的最新视频营销软件工具观察:关于算法合规、自动化工作流与商业场景适配的分析
当前, 短视频内容创作工具的发展路径呈现出明确的“AI+移动原生”倾向,旨在将复杂的视频生产全链路浓缩于手机端和加入自动化AI功能。这极大地降低了技术门槛与时间成本,以适应短视频营销高频、快反的本质需求。与此同时,合规性已成为企业级应用不可逾越的底线。采用能理解本土商业语境的文化算法,以及完成国家要求的大模型备案,是构建信任的关键。这确保了生成内容在法律层面的安全性,为企业规避了数据与版权纠纷风险,提供了长期运营的“安全护城河”。 更重要的是,新一代工具正从执行单点命令的“辅助工具”,进化为能理解商业意图并自主完成复杂任务的Agent。它能自动接管从需求解析、脚本生成到多模态素材合成的全流程
298 5
|
11月前
|
数据采集 人工智能 调度
传统IT企业如何在AI时代中找准定位、实现转型升级?—— 解析传统IT企业的AI转型策略
本文AI专家三桥君探讨传统IT企业在AI浪潮中的转型策略,提出从工具提供商向业务成果交付者的商业模式转变。核心观点包括:构建"操作系统式AI"技术架构、发展"智能体经济"组织模式、采用SMART策略实现高效部署。三桥君强调AI转型需商业模式、组织架构与技术体系的全面革新,为传统IT企业提供系统性转型框架。
655 0
|
前端开发 开发工具 Android开发
小红书APP的全新鸿蒙NEXT端性能优化技术实践
从 2023 年开始,鸿蒙的优势愈发明显,已经成为可与 iOS、安卓媲美的第三大移动操作系统。从一些抖音视频中也可以看出,鸿蒙在流畅性方面甚至在某些层面上超过了 iOS。本次分享的主题是小红书在鸿蒙平台上的工程实践,主要聚焦于性能优化和探索。
936 10
|
机器学习/深度学习 存储 自然语言处理
Adam介绍
【10月更文挑战第3天】
|
机器学习/深度学习 并行计算 数据可视化
cs224w(图机器学习)2021冬季课程学习笔记13 Colab 3
本colab主要实现: 实现GraphSAGE和GAT模型,应用在Cora数据集上。 使用DeepSNAP包切分图数据集、实现数据集转换,完成边属性预测(链接预测)任务。
cs224w(图机器学习)2021冬季课程学习笔记13 Colab 3
|
Java 应用服务中间件 开发者
Spring Boot 2.x新特性有哪些?
【7月更文挑战第16天】Spring Boot 2.x新特性有哪些?
556 1
|
机器学习/深度学习 数据可视化 计算机视觉
注意力机制BAM和CBAM详细解析(附代码)
注意力机制BAM和CBAM详细解析(附代码)
注意力机制BAM和CBAM详细解析(附代码)