降龙十八掌:这套优化transformer内存占用的组合技值得收藏(1)

简介: 降龙十八掌:这套优化transformer内存占用的组合技值得收藏


很多时候,内存限制会阻碍 ViT 以及 LLM 的训练,这篇文章介绍了 9 种减少内存消耗峰值的方法。难能可贵的是,这几种方法可以同时使用,就好像降龙十八掌中最后一掌,正是将前几张组合起来就能打出最强大的效果。


峰值内存消耗是训练深度学习模型(如视觉 Transformer 和 LLM)时的常见瓶颈。本文提供了一系列可以在不牺牲建模性能和预测精度的情况下,将 PyTorch 中的内存消耗降低到约 1/20 的技术。

以 PyTorch 的 Torchvision 库中的视觉 transformer 为基础,本文作者编写了大约 100 行代码的训练脚本,并且所有代码示例都可以在 GitHub 上找到。

以下是本文将要介绍的技术名称:

  • 微调 vision transformer
  • 自动混合精度训练
  • 低精度训练
  • Reduced Batch Size 训练
  • 梯度积累与 Microbatches
  • 选择更精简的优化器
  • 在目标设备上实例化模型
  • 分布式训练与张量共享
  • 参数卸载
  • 以上九种方法结合起来,就形成了一种可以用于 llm 的综合方法,也可以称之为第十种方法。


这些方法是互相解耦的,可以将它们叠加在一起使用。

本文在实验中使用的 ViT 为 ViT-L-16 模型。在依次将上述方法添加后,研究者将训练 BigBird-Roberta LLM 来执行文本分类任务。这些技术使得在消费类硬件上训练这样的模型成为可能。

微调 vision transformer

为了简化实验中的 PyTorch 代码,本文使用了开源库 ——Fabric,十几行代码就能应用各种先进的 PyTorc 技术(自动混合精度训练、多 GPU 训练、张量分片等)。

原生 PyTorch 代码和修改后的使用 Fabric 的代码之间的区别很微妙,只有较小的修改,如下面的代码所示:


如上所述,改动虽然不大,但是可以方便的使用 PyTorch 中的高级功能,而无需重新构造任何现有代码。

总结上图,将普通 PyTorch 代码转换为 PyTorch+Fabric 的主要 3 个步骤可以归纳如下:


  1. 导入 Fabric 并实例化 Fabric 对象。
  2. 使用 Fabric 设置模型、优化器和数据加载程序。
  3. 调用 fabric.backward () 构造损失函数,而不是通常使用的 loss.backward ()


使用普通 PyTorch 和 PyTorch with Fabric 的性能和内存消耗几乎完全相同:

Plain PyTorch (01_pytorch-vit.py):




Time elapsed 17.94 minMemory used: 26.79 GBTest accuracy 95.85%


PyTorch with Fabric (01-2_pytorch-fabric.py)





Time elapsed 17.88 minMemory used: 26.84 GBTest accuracy 96.06%


也可以将下面的代码


model = vit_l_16(weights=ViT_L_16_Weights.IMAGENET1K_V1)


替换为


model = vit_l_16(weights=None)


替换后,将不再是微调,而是从头开始训练相同的 ViT 架构,预测准确率会从 96% 以上下降到约 60%:


自动混合精度

上一节使用 Fabric 修改了 PyTorch 代码,在此基础上,使用混合精度和分布式训练,也只需更改一行代码。

应用混合精度训练

应用混合精度训练,只需一个小的修改,将下面这行代码


fabric = Fabric(accelerator="cuda", devices=1)


替换为


fabric = Fabric(accelerator="cuda", devices=1, precision="16-mixed")


之后,在不牺牲预测精度的情况下,内存消耗从 26.84GB 减少到 18.21GB,如下所示:

01-2_pytoch-fabric.py 和 02_mixed-precision.py 的结果对比

此外,混合精确训练不仅减少了内存使用,还将运行时间减少了 6 倍(从 17.88 分钟减少到 3.45 分钟),这可以说是意外收获。

什么是混合精度训练?

混合精度训练同时使用 16 位和 32 位精度,以确保不损失精度。16 位表示的梯度计算比 32 位格式快得多,并且节省了大量的内存。这种策略是有益的,尤其是当受到内存或计算限制时。

之所以被称为「混合」而不是「低」精度训练的原因是,并不会将所有参数和操作都转移成 16 位浮点数。实际上,在训练期间会在 32 位和 16 位运算之间切换。

如下图所示,混合精度训练可以分解为:将权重转换为较低精度(如 FP16)以实现更快的计算、计算梯度、将梯度转换回较高精度(FP32)以实现数值稳定性,以及用缩放的梯度更新原始权重等几个步骤。


这种方法在保证训练有效的前提下,还能保持神经网络的准确性和稳定性。

感兴趣的读者还可以在本文作者的另一篇文章:《使用混合精度技术加速大型语言模型》中获得更多底层概念。

文章地址:https://lightning.ai/pages/community/tutorial/accelerating-large-language-models-with-mixed-precision-techniques/

低精度训练

还可以更进一步,尝试以「完全」较低的 16 位精度运行,而不是混合精度。

将下面这行代码


fabric = Fabric(accelerator="cuda", precision="16-mixed")


替换为


fabric = Fabric(accelerator="cuda", precision="16-true")


但需要注意的是,这样会在训练中产生 NaN 值:





Epoch: 0001/0001 | Batch 0000/0703 | Loss: 2.4105Epoch: 0001/0001 | Batch 0300/0703 | Loss: nanEpoch: 0001/0001 | Batch 0600/0703 | Loss: nan...


这是因为常规的 16 位浮点只能表示 - 65504 和 65504 之间的数字:





In [1]: import torch
In [2]: torch.finfo(torch.float16)Out[2]: finfo(resolution=0.001, min=-65504, max=65504, eps=0.000976562, smallest_normal=6.10352e-05, tiny=6.10352e-05, dtype=float16)


因此,为了避免 NaN 问题,可以将参数修改为「bf16 true」:


fabric = Fabric(accelerator="cuda", precision="bf16-true")


可以将内存消耗进一步降低到 13.82 GB(同样,在不牺牲准确性的情况下):

将 03_bfloat16.py 与之前的代码的结果进行比较

什么是 Bfloat16?

「bf16 mixed」中的「bf16」代表 Brain Floating Point(bfloat16)。谷歌为机器学习和深度学习应用程序开发了这种格式,特别是在其张量处理单元(TPU)中。与传统 float16 格式相比,Bfloat16 以降低精度为代价扩展了动态范围。


扩展的动态范围有助于 bfloat16 表示非常大和非常小的数字,使其更适合可能遇到广泛值的深度学习应用。然而,较低的精度可能会影响某些计算的准确性,或在某些情况下导致舍入误差。但在大多数深度学习应用中,这种精度的降低对建模性能的影响微乎其微。

虽然 bfloat16 最初是为 TPU 开发的,但这种格式从 A100 Tensor Core GPU 开始,也得到了其之后的 NVIDIA GPU 的支持。

以下代码可以检查 GPU 是否支持 bfloat16:




>>> import torch>>> torch.cuda.is_bf16_supported()True


减少批大小

减少批大小通常是减少内存消耗的一个有效方法。然而,它有时会导致较差的预测性能,因为这样要改变训练动态。

无论哪种方式,需要探讨减少批量大小对结果有何影响。事实证明,可以在不牺牲性能的情况下将批大小降低到 16,从而将内存消耗降至 5.69 GB:

将 04_lower-batchsize.py 与以前的代码进行比较。

相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
1月前
|
存储 缓存 监控
|
12天前
|
缓存 算法 Java
本文聚焦于Java内存管理与调优,介绍Java内存模型、内存泄漏检测与预防、高效字符串拼接、数据结构优化及垃圾回收机制
在现代软件开发中,性能优化至关重要。本文聚焦于Java内存管理与调优,介绍Java内存模型、内存泄漏检测与预防、高效字符串拼接、数据结构优化及垃圾回收机制。通过调整垃圾回收器参数、优化堆大小与布局、使用对象池和缓存技术,开发者可显著提升应用性能和稳定性。
33 6
|
12天前
|
监控 安全 程序员
如何使用内存池池来优化应用程序性能
如何使用内存池池来优化应用程序性能
|
12天前
|
存储 监控 Java
深入理解计算机内存管理:优化策略与实践
深入理解计算机内存管理:优化策略与实践
|
23天前
|
存储 JavaScript 前端开发
如何优化代码以避免闭包引起的内存泄露
本文介绍了闭包引起内存泄露的原因,并提供了几种优化代码的策略,帮助开发者有效避免内存泄露问题,提升应用性能。
|
24天前
|
并行计算 算法 IDE
【灵码助力Cuda算法分析】分析共享内存的矩阵乘法优化
本文介绍了如何利用通义灵码在Visual Studio 2022中对基于CUDA的共享内存矩阵乘法优化代码进行深入分析。文章从整体程序结构入手,逐步深入到线程调度、矩阵分块、循环展开等关键细节,最后通过带入具体值的方式进一步解析复杂循环逻辑,展示了通义灵码在辅助理解和优化CUDA编程中的强大功能。
|
1月前
|
存储 弹性计算 算法
前端大模型应用笔记(四):如何在资源受限例如1核和1G内存的端侧或ECS上运行一个合适的向量存储库及如何优化
本文探讨了在资源受限的嵌入式设备(如1核处理器和1GB内存)上实现高效向量存储和检索的方法,旨在支持端侧大模型应用。文章分析了Annoy、HNSWLib、NMSLib、FLANN、VP-Trees和Lshbox等向量存储库的特点与适用场景,推荐Annoy作为多数情况下的首选方案,并提出了数据预处理、索引优化、查询优化等策略以提升性能。通过这些方法,即使在资源受限的环境中也能实现高效的向量检索。
|
3月前
|
存储 编译器 C语言
【C语言篇】数据在内存中的存储(超详细)
浮点数就采⽤下⾯的规则表⽰,即指数E的真实值加上127(或1023),再将有效数字M去掉整数部分的1。
377 0
|
25天前
|
存储 C语言
数据在内存中的存储方式
本文介绍了计算机中整数和浮点数的存储方式,包括整数的原码、反码、补码,以及浮点数的IEEE754标准存储格式。同时,探讨了大小端字节序的概念及其判断方法,通过实例代码展示了这些概念的实际应用。
53 1
|
29天前
|
存储
共用体在内存中如何存储数据
共用体(Union)在内存中为所有成员分配同一段内存空间,大小等于最大成员所需的空间。这意味着所有成员共享同一块内存,但同一时间只能存储其中一个成员的数据,无法同时保存多个成员的值。