IJCAI2023 | 高效训练Transformers的方法

简介: IJCAI2023 | 高效训练Transformers的方法

前言

论文:https://arxiv.org/pdf/2302.01107.pdf

深度学习是近年来最重要的方法之一,它彻底改变了机器学习和人工智能,并引领着第四次工业革命。训练GPT-3(1750亿参数)需要355个GPU年,并且至少花费460万美元。

此外随着基于注意力机制的模型规模的指数增长,训练内存也在相应地增加。例如,最大的语言模型从2018年的BERT-large(3.45亿参数)发展到目前具有数百亿参数的模型。

本文总结了用于训练基于注意力机制的模型(即Transformers)的通用技术。通过技术创新和主要用例对这些技术进行了分类,总结并在它们之间的联系。


一、Computation Effciency

1. Optimization(优化器)

为了实现梯度下降的更快收敛速度,一个经典的解决方案是融合动量技术,其中每一步都是陡峭下降方向和最近迭代移位的组合,有助于在相关方向上加速梯度下降并减缓振荡。

里程碑式的工作包括用于凸优化的Nesterov的加速梯度 [Nesterov, 1983] 和用于非凸问题的带动量的近端梯度 [Li etal, 2017] 等。为了满足机器学习模型的大规模优化需求,主导性的优化器以随机方式设计。

带有动量的随机梯度下降(SGD)和自适应学习率估计方法Adam被广泛用于训练深度神经网络。根据经验,使用Adam来训练Transformers优于使用SGD。

默认情况下,Adamw是Transformers最广泛使用的优化器之一,它是Adam的一个变体,将L2正则化和权重衰减分离。

最近,谷歌搜索优化算法并发现了一种简单而有效的优化器,称为Lion。Lion只跟踪第一阶梯度的动量,其更新仅考虑方向并且每个参数的大小相同,这与像Adamw这样的自适应优化器非常不同。

2. Initialization(参数初始化)

良好的初始化对于稳定训练、启用更高的学习率、加速收敛并提高泛化能力至关重要。

  • Fixup:提出了适当调整标准初始化的方法,以确保适当的梯度范数,避免梯度爆炸或梯度消失。这使得可以在不添加归一化层的情况下训练具有超过10,000层的非常深的网络。
  • ReZero和SkipInit:简单地将每一层初始化为执行身份操作。它们在每个残差块的输出上添加一个可学习的缩放乘数:

  • T-Fixup:针对Transformers进行了定制,并分析了Adam优化器中早期更新的不稳定性,因为二阶动量的方差是无界的。因此,它采用了Fixup的尺度调整方案来初始化残差块。

3. Sparse training(稀疏训练)

稀疏训练的关键思想是直接训练稀疏子网络,而不是从头开始训练完整的网络,同时不损失准确性。最早的证明来自“中彩票假设”(Lottery Ticket Hypothesis,LTH),即一个密集、随机初始化的网络包含可以独立训练以匹配原始网络准确性的子网络(中彩票)。

考虑到这一点,后续的研究提出了更高训练效率的方法,大致可以分为三类:

  1. 在初始化时,通过测量连接在损失上的重要性一次性找到稀疏网络,消除了复杂的迭代优化计划;
  2. 通过低成本方案在Transformers的早期训练阶段识别中彩票,然后仅训练这些早期票直到收敛;
  3. 使用交替的剪枝和生长计划,在整个训练过程中动态更新模型的稀疏模式,适用于通用架构。

4. Overparameterization(过参数化)

参数化是指神经网络的可学习参数数量远超过训练样本的数量。

实践中观察到,过参数化能有效改善网络的收敛速度和泛化能力,尽管理论上的保证不充分。早期的研究证明,在线性神经网络中增加深度可以加速随机梯度下降(SGD)的收敛。后续的研究扩展到了双层非线性神经网络,并且证明在一定假设下,SGD可以在多项式时间内收敛到深度神经网络训练目标的全局最小值。

在泛化方面,理论上证明了足够过参数化的三层神经网络能够泛化到总体风险,并且存在一种有趣的性质:在随机初始化的高概率下,SGD训练轨迹上任一点的临近区域都存在一个准确的网络。这与剪枝理论(LTH)有着深刻的联系,因为它部分解释了为什么在稀疏训练中LTH依然有效,即由于过参数化,存在大量风险低的良好小型子网络。

在Transformer模型中,通过利用过参数化理论中的快速收敛和更好的泛化,设计了一种高效的训练流程:训练一个非常大的模型,然后进行早期停止并大幅度压缩,这与LTH类似。

5. Large batch training(大批量训练)

加速训练的另一种流行方法是使用大批量大小,每个时期提供较少的迭代次数,并更好地利用计算资源。

从统计学的角度来看,大批量训练减小了随机梯度估计的方差,因此需要调整可靠的步长以获得更好的收敛性。在卷积神经网络时代,使用学习率的线性缩放,在1小时内使用8,192的批量大小在ImageNet上训练ResNet-50。然后提出了更先进的步长估计方法。

广泛使用的方法有SGD 的 LARS 和Adam 的 LAMB,它们分别为ResNet 和Transformers 提出了使用层自适应学习率的方法。层自适应策略可以被公式化为:

其中 , , 分别是时间步 t 时的学习率、第 i 层的参数和基于动量的梯度, 是一个缩放函数。

6. Incremental learning(增量学习)

增量学习是将原始的具有挑战性的优化问题放宽为一系列易于优化的子问题,其中一个子问题的解决方案可以作为后续子问题的良好初始化,以规避训练困难,类似于退火。

一些工作提出通过逐步堆叠层来加速BERT预训练,从较小的模型正确初始化较大的模型。以相反的方向进行,通过层丢弃以随机深度来训练Transformers,逐渐增加沿着时间维度和深度维度的丢弃率。

对于ViT,AutoProg 提出使用神经结构搜索自动决定模型在渐进学习过程中何时、在何处以及以多大程度增长。一个关键的观察是,逐渐增加输入图像的分辨率(减小补丁大小)可以显著加速ViT的训练,与广泛已知的训练动态相一致,即在早期阶段专注于低频结构,而在后期阶段专注于高频语义。

二、Data Selection

1. Token masking(Token掩码)

Token掩码是自监督预训练任务中的主导方法,例如掩码语言建模(MLM)和掩码图像建模(MIM)。

Token掩码是随机掩盖一些输入标记,并训练模型以预测缺失的内容,例如词汇ID或像素,使用可见标记的上下文信息。由于压缩序列长度会二次减小计算和内存复杂性,跳过处理掩码标记对MLM和MIM带来了相当大的训练效率提升。

2. Importance sampling(重要性采样)

通过对数据进行的重要性采样,也被称为数据修剪,理论上保证了通过优先考虑信息丰富的训练示例来加速监督学习的随机梯度算法,主要受益于方差减小。

对于深度神经网络,估计每个样本的重要性的一种主要方式是使用梯度范数,使用不同的近似方法使计算这些范数变得可行。

三、Memory Efficiency

1. 模型尺寸和内存效率

使用跨设备并行计算训练大型深度神经网络是一种常见的实践以满足内存需求。基本上有两种范式:

  • 数据并行(DP)将数据的小批量分布到不同的设备上
  • 模型并行(MP)将模型的子图分配到多个工作器上。

对于DP,随着可用工作器的增加,批量大小接近线性缩放。在第2节中讨论的大批量训练是为此情况开发的。然而很明显DP具有高的通信/计算效率,但内存效率较差。当模型变大时,单个设备无法存储模型副本,梯度的同步通信可能阻碍DP的可扩展性。

2. Quantized training(量化训练)

标准的神经网络训练例程采用全精度(即FP32)。相反,量化训练通过将激活/权重/梯度压缩为低位值(例如FP16或INT8)从头开始以降低精度进行神经网络训练。先前的研究已经表明,减小精度训练可以加速神经网络的训练,并具有良好的性能。

对于Transformers,最广泛采用的方法是自动混合精度(AMP)训练。具体而言,AMP在全精度中存储权重的主副本用于更新,而激活、梯度和权重则以FP16存储用于算术计算。与全精度训练相比,AMP能够实现更快的训练/推断速度,并在网络训练期间减少内存消耗。

3. Rematerialization and offloading(重计算和卸载)

重计算也称为检查点技术[Chen et al., 2016],是一种广泛使用的时空权衡技术,只在前向传递期间存储部分激活/权重,并在反向传递期间重新计算其余部分。

至于卸载,这是一种使用外部内存(如CPU内存)作为GPU内存的扩展,通过GPU和CPU之间的通信来增加训练期间的内存容量。模型状态以及激活可以被卸载到CPU,但最佳选择需要最小化与GPU之间的通信成本(即数据移动),减少CPU计算并最大化GPU内存节省。

一个代表性的工作是ZeRO-Offoad,它提供了针对使用Adam优化器的混合精度训练的最佳卸载策略。它将所有fp32模型状态和fp16梯度卸载到CPU内存,并在CPU上计算fp32参数更新。fp16参数保留在GPU上,前向和反向计算在GPU上进行。

4. Parameter-efficient tuning(参数效率调整)

作为普通全微调的强大替代方法,参数效率调整(PET)仅更新少量额外的参数,同时将预训练模型冻结,以显著减少存储负担,并且在不需要为每种情况存储单独的模型实例的情况下适用于动态部署场景。

一般的PET方法可以分为基于添加的方法和基于重参数化的方法。前者在预训练模型上附加额外的可训练参数并仅调整这些参数。

四、Hardware/Algorithm Co-design

1. 高效硬件加速器设计

除了计算和内存负担外,设计高效的硬件加速器可以加速DNN的训练和推理。具体而言,与中央处理单元(CPU)相比,图形处理单元(GPU)由于高度的并行性而更强大,特别适合执行矩阵乘法。对于专注于特定计算任务的应用,应用特定集成电路(ASIC)具有低功耗和高训练/推理速度的优势。

相关实践学习
在云上部署ChatGLM2-6B大模型(GPU版)
ChatGLM2-6B是由智谱AI及清华KEG实验室于2023年6月发布的中英双语对话开源大模型。通过本实验,可以学习如何配置AIGC开发环境,如何部署ChatGLM2-6B大模型。
目录
相关文章
|
机器学习/深度学习
大模型训练loss突刺原因和解决办法
【1月更文挑战第19天】大模型训练loss突刺原因和解决办法
1899 1
大模型训练loss突刺原因和解决办法
|
3月前
|
机器学习/深度学习 数据可视化 PyTorch
Flow Matching生成模型:从理论基础到Pytorch代码实现
本文将系统阐述Flow Matching的完整实现过程,包括数学理论推导、模型架构设计、训练流程构建以及速度场学习等关键组件。通过本文的学习,读者将掌握Flow Matching的核心原理,获得一个完整的PyTorch实现,并对生成模型在噪声调度和分数函数之外的发展方向有更深入的理解。
1290 0
Flow Matching生成模型:从理论基础到Pytorch代码实现
|
机器学习/深度学习 人工智能 自然语言处理
深度学习还不如浅层网络?RL教父Sutton持续反向传播算法登Nature
【9月更文挑战第24天】近年来,深度学习在人工智能领域取得巨大成功,但在连续学习任务中面临“损失可塑性”问题,尤其在深度强化学习中更为突出。加拿大阿尔伯塔大学的研究人员提出了一种名为“持续反向传播”的算法,通过选择性地重新初始化网络中的低效用单元,保持模型的可塑性。该算法通过评估每个连接和权重的贡献效用来决定是否重新初始化隐藏单元,并引入成熟度阈值保护新单元。实验表明,该算法能显著提升连续学习任务的表现,尤其在深度强化学习领域效果明显。然而,算法也存在计算复杂性和成熟度阈值设置等问题。
232 2
|
资源调度 关系型数据库 MySQL
docker制作compose
本文介绍了Docker Compose的基本使用,包括安装、创建`docker-compose.yml`文件定义服务,以及如何使用环境变量和卷来配置多容器应用的步骤。
568 1
docker制作compose
|
10月前
|
IDE 测试技术 开发工具
10个必备Python调试技巧:从pdb到单元测试的开发效率提升指南
在Python开发中,调试是提升效率的关键技能。本文总结了10个实用的调试方法,涵盖内置调试器pdb、breakpoint()函数、断言机制、logging模块、列表推导式优化、IPython调试、警告机制、IDE调试工具、inspect模块和单元测试框架的应用。通过这些技巧,开发者可以更高效地定位和解决问题,提高代码质量。
1010 8
10个必备Python调试技巧:从pdb到单元测试的开发效率提升指南
|
10月前
|
搜索推荐 数据挖掘 数据处理
《探索 Faiss:原理与应用解析》
在数据驱动的时代,高效处理和搜索海量数据至关重要。Faiss 是一个专为大规模相似性搜索和聚类设计的库,擅长处理高维向量数据,广泛应用于文本处理、图像识别等领域。本文深入解析 Faiss 的原理、使用方法及其在图像检索、文本相似性比较和推荐系统中的实际应用,帮助读者掌握这一强大工具,提升数据处理能力。
455 2
|
缓存 JSON Java
那些年,我们写过的无效单元测试
在这篇文章里,作者通过日常的单元测试实践,系统地总结出一套避免编写无效单元测试用例的方法和原则。
527 80
那些年,我们写过的无效单元测试
|
机器学习/深度学习 PyTorch 算法框架/工具
彻底告别微调噩梦:手把手教你击退灾难性遗忘,让模型记忆永不褪色的秘密武器!
【10月更文挑战第5天】深度学习中,模型微调虽能提升性能,但也常导致灾难性遗忘,即学习新任务时遗忘旧知识。本文介绍几种有效解决方案,重点讲解弹性权重巩固(EWC)方法,通过在损失函数中添加正则项来防止重要权重被更新,保护模型记忆。文中提供了基于PyTorch的代码示例,包括构建神经网络、计算Fisher信息矩阵和带EWC正则化的训练过程。此外,还介绍了其他缓解灾难性遗忘的方法,如LwF、在线记忆回放及多任务学习,以适应不同应用场景。
1289 8
|
12月前
|
机器学习/深度学习 人工智能 算法
[大语言模型-算法优化] 微调技术-LoRA算法原理及优化应用详解
[大语言模型-算法优化] 微调技术-LoRA算法原理及优化应用详解
1346 0
[大语言模型-算法优化] 微调技术-LoRA算法原理及优化应用详解
|
Python
Polars实践(3):阿里天池——淘宝用户购物行为分析
Polars实践(3):阿里天池——淘宝用户购物行为分析
213 0