PyTorch自定义学习率调度器实现指南

本文涉及的产品
实时数仓Hologres,5000CU*H 100GB 3个月
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时计算 Flink 版,5000CU*H 3个月
简介: 本文将详细介绍如何通过扩展PyTorch的```LRScheduler```类来实现一个具有预热阶段的余弦衰减调度器。我们将分五个关键步骤来完成这个过程。

在深度学习训练过程中,学习率调度器扮演着至关重要的角色。这主要是因为在训练的不同阶段,模型的学习动态会发生显著变化。

在训练初期,损失函数通常呈现剧烈波动,梯度值较大且不稳定。此阶段的主要目标是在优化空间中快速接近某个局部最小值。然而,过高的学习率可能导致模型跳过潜在的优质局部最小值,从而限制了模型性能的充分发挥。

尽管PyTorch提供了多种预定义的学习率调度器,但在特定研究场景或需要更精细控制时,这些标准实现可能无法完全满足需求。在这种情况下,实现自定义学习率调度器成为了一个可行的解决方案。

本文将详细介绍如何通过扩展PyTorch的

LRScheduler

类来实现一个具有预热阶段的余弦衰减调度器。我们将分五个关键步骤来完成这个过程。

1、继承LRScheduler类

在PyTorch中实现自定义学习率调度器时,首先需要继承

torch.optim.lr_scheduler.LRScheduler

类。这个基类提供了管理学习率调度所需的核心功能。

通过继承

LRScheduler

,我们可以利用以下关键特性:

  1. self.optimizer:对优化器的引用,用于调整其学习率。
  2. self.base_lrs:存储优化器中所有参数组的初始学习率,可在自定义调度器中进行访问和修改。
  3. self.last_epoch:跟踪当前训练轮次,用于根据轮次数调整学习率。
  4. step()方法:在每个训练轮次后调用,用于自动更新学习率。
  5. 参数组处理:LRScheduler设计支持优化器中的多个参数组,允许对模型的不同部分应用不同的学习率调整策略。

以下是继承

LRScheduler

的基本代码结构:

 fromtorch.optim.lr_schedulerimportLRScheduler

 classCosineWarmupScheduler(LRScheduler):
     pass

通过继承

LRScheduler

,我们获得了上述所有功能,只需要通过实现

get_lr()

方法来定义学习率的具体变化逻辑。

2、实现构造函数

在自定义学习率调度器中,构造函数(

__init__

方法)用于初始化调度器的关键参数。这些参数定义了学习率调整的具体策略,包括预热期的长度、总训练轮次和最小学习率等。

以下是构造函数的实现示例:

 classCosineWarmupScheduler(LRScheduler):
     def__init__(self, optimizer, warmup_epochs, total_epochs, min_lr=0.0, last_epoch=-1):
         self.warmup_epochs=warmup_epochs    # 学习率线性增加的预热轮次
         self.total_epochs=total_epochs      # 总训练轮次
         self.min_lr=min_lr                  # 学习率下限
         super(CosineWarmupScheduler, self).__init__(optimizer, last_epoch)

参数说明:

  • optimizer:PyTorch优化器实例,其学习率将被调整。
  • warmup_epochs:预热阶段的轮次数,在此期间学习率线性增加。
  • total_epochs:训练的总轮次,包括预热阶段和衰减阶段。
  • min_lr:学习率的下限,衰减阶段的最终学习率不会低于此值。
  • last_epoch:上一轮的索引,用于恢复训练。默认为-1,表示从头开始训练。

3、调用父类构造函数

在自定义调度器的构造函数中,通过

super()

调用父类(

LRScheduler

)的构造函数是非常重要的。这确保了基类被正确初始化,使我们能够访问诸如

self.optimizer

self.base_lrs

self.last_epoch

等关键属性。

 super(CosineWarmupScheduler, self).__init__(optimizer, last_epoch)

这行代码不仅初始化了基类,还使得自定义调度器能够继承

LRScheduler

的其他有用方法,如

step()

get_last_lr()

4、实现get_lr()方法

get_lr()

方法是自定义调度器的核心,它定义了学习率如何随训练轮次变化的具体逻辑。在本例中,我们实现了一个包含预热阶段的余弦衰减调度策略:

预热阶段:在前

warmup_epochs

轮中,学习率从0线性增加到初始学习率。

余弦衰减阶段:预热结束后,学习率按余弦函数从初始值衰减到最小值。

以下是

get_lr()

方法的实现:

 importmath

 classCosineWarmupScheduler(LRScheduler):
     def__init__(self, optimizer, warmup_epochs, total_epochs, min_lr=0.0, last_epoch=-1):
         self.warmup_epochs=warmup_epochs
         self.total_epochs=total_epochs
         self.min_lr=min_lr
         super(CosineWarmupScheduler, self).__init__(optimizer, last_epoch)

     defget_lr(self):
         epoch=self.last_epoch+1

         ifepoch<=self.warmup_epochs:
             # 预热阶段:线性增加学习率
             return [base_lr*epoch/self.warmup_epochsforbase_lrinself.base_lrs]
         else:
             # 余弦衰减阶段
             decay_epochs=self.total_epochs-self.warmup_epochs
             cosine_decay=0.5* (1+math.cos(math.pi* (epoch-self.warmup_epochs) /decay_epochs))
             return [self.min_lr+ (base_lr-self.min_lr) *cosine_decayforbase_lrinself.base_lrs]

这个实现确保了学习率在预热阶段平滑增加,然后在剩余的训练过程中逐渐衰减,最终达到指定的最小值。

5、在训练流程中应用自定义调度器

实现自定义学习率调度器后,下一步是将其集成到训练流程中。以下示例展示了如何在PyTorch训练循环中初始化和使用自定义调度器:

 importtorch
 importtorch.optimasoptim

 # 定义模型(此处使用简单的线性模型作为示例)
 model=torch.nn.Linear(10, 1)

 # 初始化优化器
 optimizer=optim.SGD(model.parameters(), lr=0.1)

 # 初始化自定义学习率调度器
 scheduler=CosineWarmupScheduler(optimizer, warmup_epochs=5, total_epochs=50, min_lr=0.001)

 # 训练循环
 num_epochs=50

 forepochinrange(num_epochs):
     model.train()
     fordata, targetindataloader:
         optimizer.zero_grad()
         output=model(data)
         loss=criterion(output, target)
         loss.backward()
         optimizer.step()

     # 在每个epoch结束时更新学习率
     scheduler.step()

     # 记录当前学习率(用于监控)
     current_lr=scheduler.get_last_lr()[0]
     print(f"Epoch {epoch+1}/{num_epochs}, Learning Rate: {current_lr:.6f}")

在这个示例中,我们执行以下关键步骤:

  1. 定义模型和优化器。
  2. 使用之前实现的CosineWarmupScheduler初始化学习率调度器。
  3. 在每个训练epoch中:- 执行标准的前向传播、损失计算和反向传播步骤。- 调用optimizer.step()更新模型参数。- 在epoch结束时调用scheduler.step()更新学习率。
  4. 使用scheduler.get_last_lr()获取并记录当前学习率,用于监控训练过程。

关键组件说明

  • scheduler.step():这个方法在每个epoch结束时调用,根据当前epoch更新学习率。它是动态调整学习率的核心机制。
  • scheduler.get_last_lr():返回当前的学习率。在多参数组的情况下,它返回一个列表,每个元素对应一个参数组的学习率。

总结

通过继承PyTorch的

LRScheduler

类并实现自定义的

get_lr()

方法,我们可以创建灵活的学习率调度策略,以满足特定的训练需求。本指南展示的带预热的余弦衰减调度器只是众多可能实现的一个例子。

自定义学习率调度器的关键优势在于:

  1. 灵活性:可以实现任何所需的学习率调整策略。
  2. 精确控制:能够根据训练动态和模型特性精细调整学习过程。
  3. 适应性:可以轻松适应不同的模型架构和数据集特性。

在实际应用中,可能需要进行大量实验来确定最适合特定问题的学习率调度策略。通过掌握自定义调度器的实现技巧,研究人员和工程师可以更灵活地优化深度学习模型的训练过程,从而潜在地提高模型性能和训练效率。

https://avoid.overfit.cn/post/aa1e90e02eb24d9f982e2c933bdd97a7

目录
相关文章
|
8天前
|
弹性计算 人工智能 架构师
阿里云携手Altair共拓云上工业仿真新机遇
2024年9月12日,「2024 Altair 技术大会杭州站」成功召开,阿里云弹性计算产品运营与生态负责人何川,与Altair中国技术总监赵阳在会上联合发布了最新的“云上CAE一体机”。
阿里云携手Altair共拓云上工业仿真新机遇
|
4天前
|
机器学习/深度学习 算法 大数据
【BetterBench博士】2024 “华为杯”第二十一届中国研究生数学建模竞赛 选题分析
2024“华为杯”数学建模竞赛,对ABCDEF每个题进行详细的分析,涵盖风电场功率优化、WLAN网络吞吐量、磁性元件损耗建模、地理环境问题、高速公路应急车道启用和X射线脉冲星建模等多领域问题,解析了问题类型、专业和技能的需要。
2463 14
【BetterBench博士】2024 “华为杯”第二十一届中国研究生数学建模竞赛 选题分析
|
4天前
|
机器学习/深度学习 算法 数据可视化
【BetterBench博士】2024年中国研究生数学建模竞赛 C题:数据驱动下磁性元件的磁芯损耗建模 问题分析、数学模型、python 代码
2024年中国研究生数学建模竞赛C题聚焦磁性元件磁芯损耗建模。题目背景介绍了电能变换技术的发展与应用,强调磁性元件在功率变换器中的重要性。磁芯损耗受多种因素影响,现有模型难以精确预测。题目要求通过数据分析建立高精度磁芯损耗模型。具体任务包括励磁波形分类、修正斯坦麦茨方程、分析影响因素、构建预测模型及优化设计条件。涉及数据预处理、特征提取、机器学习及优化算法等技术。适合电气、材料、计算机等多个专业学生参与。
1502 14
【BetterBench博士】2024年中国研究生数学建模竞赛 C题:数据驱动下磁性元件的磁芯损耗建模 问题分析、数学模型、python 代码
|
1月前
|
运维 Cloud Native Devops
一线实战:运维人少,我们从 0 到 1 实践 DevOps 和云原生
上海经证科技有限公司为有效推进软件项目管理和开发工作,选择了阿里云云效作为 DevOps 解决方案。通过云效,实现了从 0 开始,到现在近百个微服务、数百条流水线与应用交付的全面覆盖,有效支撑了敏捷开发流程。
19274 29
|
1月前
|
人工智能 自然语言处理 搜索推荐
阿里云Elasticsearch AI搜索实践
本文介绍了阿里云 Elasticsearch 在AI 搜索方面的技术实践与探索。
18822 20
|
1月前
|
Rust Apache 对象存储
Apache Paimon V0.9最新进展
Apache Paimon V0.9 版本即将发布,此版本带来了多项新特性并解决了关键挑战。Paimon自2022年从Flink社区诞生以来迅速成长,已成为Apache顶级项目,并广泛应用于阿里集团内外的多家企业。
17515 13
Apache Paimon V0.9最新进展
|
6天前
|
编解码 JSON 自然语言处理
通义千问重磅开源Qwen2.5,性能超越Llama
击败Meta,阿里Qwen2.5再登全球开源大模型王座
365 11
|
1月前
|
存储 人工智能 前端开发
AI 网关零代码解决 AI 幻觉问题
本文主要介绍了 AI Agent 的背景,概念,探讨了 AI Agent 网关插件的使用方法,效果以及实现原理。
18697 16
|
2天前
|
算法 Java
JAVA并发编程系列(8)CountDownLatch核心原理
面试中的编程题目“模拟拼团”,我们通过使用CountDownLatch来实现多线程条件下的拼团逻辑。此外,深入解析了CountDownLatch的核心原理及其内部实现机制,特别是`await()`方法的具体工作流程。通过详细分析源码与内部结构,帮助读者更好地理解并发编程的关键概念。
|
2天前
|
SQL 监控 druid
Druid连接池学习
Druid学习笔记,使用Druid进行密码加密。参考文档:https://github.com/alibaba/druid
195 82