Meta 内部都在用的 FX 工具大起底:利用 Graph Transformation 优化 PyTorch 模型

本文涉及的产品
模型在线服务 PAI-EAS,A10/V100等 500元 1个月
模型训练 PAI-DLC,100CU*H 3个月
交互式建模 PAI-DSW,每月250计算时 3个月
简介: Meta 内部都在用的 FX 工具大起底:利用 Graph Transformation 优化 PyTorch 模型

image.png

PyTorch 中的 graph mode 在性能方面表示更为出色,本文介绍 Torch.FX 这个强大工具,可以捕捉和优化 PyTorch 程序 graph。

一、简介

PyTorch 支持两种执行模式:eager mode 和 graph mode。

eager mode 中,模型中的运算符在读取时会立即执行,它易于使用,对机器学习从业者更友好,因此被设置为默认的执行模式。

graph mode 中,运算符先被合成一个 graph,然后作为一个整体进行编译和执行,它的性能更高,因此在实际生产中大量使用。

具体来说,graph mode 支持算子融合,两个算子通过合并,可以降低或本地化内存读取以及内核启动总开销。

融合可以是横向 (horizontal) 的:采取应用于多个 operand 的单一操作(如 BatchNorm),并将这些 operand 合并到一个数组中。

融合也可以是纵向 (vertical) 的:将一个内核与另一个内核合并,后者需要使用第一个内核的输出(如 ReLU 后接卷积)。

Torch.FX(缩写为 FX)是一个公开可用的工具包,作为 PyTorch 软件包的一部分,支持 graph mode 的执行。它可以:

  1. 从 PyTorch 程序中获取 graph
  2. 允许开发者在获取的 graph 上编写 transformation**

Meta 内部先前已经在用 FX 来优化生产模型 (production model) 的训练吞吐量 (training throughput)。本文将通过介绍 Meta 开发的基于 FX 的优化,来展示利用图结构转换 (graph transformation) 优化 PyTorch 部署模型性能的方法

二、背景

embedding table 广泛存在于推荐系统中,本节将介绍 FX 和 embedding table 的背景知识。

2.1. FX

图 1 是一个简单示例,演示了如何用 FX 转换 PyTorch 程序,它包含三个步骤:

  • 从程序中获取 graph
  • 修改 graph(在本例中,我们用 GELU 代替 RELU)
  • 从修改后的 graph 中生成一个新程序

image.png

图1:在 PyTorch 模块中用 GELU 取代 RELU 的 FX

FX API 为检查和转换 PyTorch 程序 graph 还提供了许多其他功能。

2.2. embedding table

image.png

图2:批尺寸=1 的稀疏特征 embedding table 示意图

在推荐系统中,稀疏特征(例如,User ID,Story ID)由 embedding table 表示。

embedding table E 是一个 HxD 矩阵,其中 H 是哈希大小,D 是嵌入向量维度。E 的每一行都是一个浮点数向量。

feature hashing 的作用是将一个稀疏特征映射到 E的索引列表中,例如 [S1,S2,...,Sk],其中 0≤Si<H。它的输出值计算为 f(E[S1], E[S2],...,E[Sk]),其中 E[Si] 是 Si 行的向量,f 是池化函数,通常是 sum,average,max 三个函数之一。

为了充分利用 GPU,稀疏特征通常为批处理。批处理中的每个实体都有自己的索引列表。如果一个批次有 B 个实体,可以简单理解为一个表征有 B 个索引列表。

更为严谨的表示方法是将 B 个索引列表合并成一个索引列表,并添加一个索引长度的列表(该批中的每个实体都有一个长度 length)。

例如,如果一批包含 3 个实体,其索引列表如下:

  • Entity 1: indices = [10, 20]
  • Entity 2: indices = [5, 9, 77, 81]
  • Entity 3: indices = [15, 20, 45]

则完整批尺寸的 indice 和 length 将是:

  • Indices = [10, 20, 5, 9, 77, 81, 15, 20, 45]
  • Lengths = [2, 4, 3]

而整个 batch 的 embedding table 查询,输出为是一个 BxD 矩阵。

三、3 种 FX Transformation

PyTorch 更新了 3 个 FX transformation,以加速对 embedding table 的访问,本节将逐一介绍。

下文 3.1 关于将多个小输入张量结合成一个大张量的转换;3.2 关于将多个并行计算链融合成一个计算链的转换;3.3 关于将通信与计算重叠的转换。

3.1 结合输入稀疏特征

batch 中的每个输入稀疏特征,都可以表示为两个列表:一个索引列表和一个 B length 列表,其中 B 表示批尺寸。

在 PyTorch 中,这两个列表都可以以张量的形式存在。当 PyTorch 模型在 GPU 上运行时,embedding table 通常存储在 GPU 内存中(它更接近 GPU,读写带宽比 CPU 内存更高)。

需要使用输入稀疏特征时,两个张量都要先从 CPU 复制到 GPU。然而每个主机到设备的内存复制都需要启动内核,这对于实际的数据传输来说,会更加耗费时间。

如果一个模型使用了多个输入稀疏特征,这种复制可能成为性能瓶颈(例如,1000 个输入稀疏特征将需要从主机到设备复制 2000 个张量)。

一个减少主机到设备 memcpy 数量的优化方法,就是在多个输入稀疏特征发送到设备之前,先将其进行组合。

例如,给定以下三个输入特征:

  • Feature_A: indices = [106, 211, 7], lengths = [2, 1]
  • Feature_B: indices = [52, 498, 616, 870, 1013], lengths = [3, 2]
  • Feature_C: indices = [2011, 19, 351, 790], lengths = [1, 3]

组合后的形式为:

Features_A_B_C: indices = [106, 211, 7, 52, 498, 616, 870, 1013, 2011, 19, 351, 790], lengths = [2, 1, 3, 2, 1, 3]

所以不需要从主机到设备复制 3x2=6 个张量,只需要复制 2 个张量。

图 3(b) 描述了这种优化的实现,它包含两个组件:

  • CPU 端:输入 pipeline 被修改为将所有稀疏特征的 indices 组合成一个张量,所有 length 组合成另一个张量。然后将这两个张量复制到 GPU 上。
  • GPU 端:使用 FX,在模型 graph 中插入一个Permute_and_Split 算子,从合并的张量中恢复单个特征 indices 和 length 张量,并将其发送至下游的相应节点。

image.png

优化前:两个张量都要从 CPU 复制到 GPU

image.png

优化后:将输入稀疏特征进行组合

3.2 从访问 embedding table 开始的计算链横向融合

在一个生产模型中,每个 GPU 上有 10 个 embedding table 很常见。出于性能方面的考虑,对这些 table 的查询被分到一组,这样它们的输出就被串联在一个大张量中(见图 4(a)中的红色部分)。

为了对单个特征输出进行计算,使用 Split 算子将大张量分成 N 个小张量(其中 N 为特征的数量),然后将所需的计算应用于每个张量。

如图 4(a) 所示,应用于每个特征输出 O 的计算是Tanh(LayerNorm(O))。所有的计算结果都被串联成一个大的张量,然后传递给下游的算子(图 4(a) 中的 Op1)。

这里主要的 runtime cost 是 GPU 内核启动的开销。例如,图 4(a) 中的 GPU 内核的启动次数为 2*N+3(图中的每个椭圆都表示一个 GPU 内核)。这会影响性能,因为 LayerNorm 和 Tanh 在 GPU 上的执行时间,与它们的内核启动时间相比很短。

此外,Split 算子可能会创建一个额外的嵌入向量输出张量的副本,消耗额外的 GPU 内存。

用 FX 来实现一种叫做横向融合 (horizontal fusion) 的优化,可以大大减少 GPU 内核的启动次数(在这个例子中,优化后的 GPU 内核启动次数为 5,见图 4(b))。

使用 Add_middle_dim 算子代替显式 Split,将 shape 为 (B, NxD) 的 2D 嵌入张量重塑为 shape 为 (B, N, D) 的 3D 张量。接下来将一个单一的 LayerNorm 应用到它的最后一维。对 LayerNorm 的结果应用一个 Tanh。最后,用 Remove_middle_dim 算子将 Tanh 的结果恢复成 2D 张量。

由于 Add_middle_dim 和 Remove_middle_dim 只是重塑张量,并没有创建额外的副本,所以也可以减少 GPU 内存的消耗。

image.png

优化前:所有输出被串联到一个大张量中

image.png

进行横向融合优化后

3.3 计算与通信间的重叠 (overlap)

面向投产的推荐模型的训练,通常是在分布式 GPU 系统上完成的。由于每个 GPU 的设备内存容量不足以容纳模型中的所有 embedding table,因此需要将其分布在多个 GPU 上。

在训练步骤中,GPU 需要从其他 GPU 上的 embedding table 中读取/写入特征值。这被称为 all-to-all 通信,可能是影响性能的重要原因。

通过 FX 实现一个 transformation,可以将计算与 all-to-all 通信重叠。图 5(a) 显示了一个具备嵌入向量 table 访问 (EmbeddingAllToAll) 及其他算子的模型 graph 实例。如图 5(b) 所示,在没有任何优化的情况下,它们会在一个 GPU 流上顺序执行。

使用FX将 EmbeddingAllToAll 分成 EmbeddingAllToAll_Request和EmbeddingAllToAll_Wait,并在它们之间安排独立的算子。

image.png

图5:计算与通信的重叠

3.4 总结

image.png

表1:本节讨论的优化及解决的相应性能瓶颈

为了发现哪些模型会从这些 transformation 中受益,开发人员对 MAIProf 收集的运行在 Meta 数据中心的模型的性能数据进行分析。得出与 eager mode 相比,这些 transformation 在一组生产模型上实现了 2-3 倍的速度提升。

四、结语

从性能角度考量,PyTorch 中的 graph mode 比生产环境中使用的 eager mode 更受欢迎。FX 是一个强大的工具,可以捕捉和优化 PyTorch 程序 graph。本文展示了三种 FX transformation,用于优化 Meta 内部的生产推荐模型。

最后希望更多 PyTorch 开发者可以使用 graph transformation 来提升模型的性能。

—— 完 ——

相关实践学习
在云上部署ChatGLM2-6B大模型(GPU版)
ChatGLM2-6B是由智谱AI及清华KEG实验室于2023年6月发布的中英双语对话开源大模型。通过本实验,可以学习如何配置AIGC开发环境,如何部署ChatGLM2-6B大模型。
相关文章
|
18天前
|
机器学习/深度学习 数据采集 人工智能
PyTorch学习实战:AI从数学基础到模型优化全流程精解
本文系统讲解人工智能、机器学习与深度学习的层级关系,涵盖PyTorch环境配置、张量操作、数据预处理、神经网络基础及模型训练全流程,结合数学原理与代码实践,深入浅出地介绍激活函数、反向传播等核心概念,助力快速入门深度学习。
71 1
|
3月前
|
机器学习/深度学习 PyTorch 测试技术
从训练到推理:Intel Extension for PyTorch混合精度优化完整指南
PyTorch作为主流深度学习框架,凭借动态计算图和异构计算支持,广泛应用于视觉与自然语言处理。Intel Extension for PyTorch针对Intel硬件深度优化,尤其在GPU上通过自动混合精度(AMP)提升训练与推理性能。本文以ResNet-50在CIFAR-10上的实验为例,详解如何利用该扩展实现高效深度学习优化。
161 0
|
18天前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
55 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
2月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.0性能优化实战:4种常见代码错误严重拖慢模型
我们将深入探讨图中断(graph breaks)和多图问题对性能的负面影响,并分析PyTorch模型开发中应当避免的常见错误模式。
146 9
|
2月前
|
机器学习/深度学习 算法 数据可视化
近端策略优化算法PPO的核心概念和PyTorch实现详解
本文深入解析了近端策略优化(PPO)算法的核心原理,并基于PyTorch框架实现了完整的强化学习训练流程。通过Lunar Lander环境展示了算法的全过程,涵盖环境交互、优势函数计算、策略更新等关键模块。内容理论与实践结合,适合希望掌握PPO算法及其实现的读者。
292 2
近端策略优化算法PPO的核心概念和PyTorch实现详解
|
3月前
|
机器学习/深度学习 数据可视化 PyTorch
Flow Matching生成模型:从理论基础到Pytorch代码实现
本文将系统阐述Flow Matching的完整实现过程,包括数学理论推导、模型架构设计、训练流程构建以及速度场学习等关键组件。通过本文的学习,读者将掌握Flow Matching的核心原理,获得一个完整的PyTorch实现,并对生成模型在噪声调度和分数函数之外的发展方向有更深入的理解。
1127 0
Flow Matching生成模型:从理论基础到Pytorch代码实现
|
5月前
|
机器学习/深度学习 PyTorch API
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
本文深入探讨神经网络模型量化技术,重点讲解训练后量化(PTQ)与量化感知训练(QAT)两种主流方法。PTQ通过校准数据集确定量化参数,快速实现模型压缩,但精度损失较大;QAT在训练中引入伪量化操作,使模型适应低精度环境,显著提升量化后性能。文章结合PyTorch实现细节,介绍Eager模式、FX图模式及PyTorch 2导出量化等工具,并分享大语言模型Int4/Int8混合精度实践。最后总结量化最佳策略,包括逐通道量化、混合精度设置及目标硬件适配,助力高效部署深度学习模型。
683 21
PyTorch量化感知训练技术:模型压缩与高精度边缘部署实践
|
7月前
|
机器学习/深度学习 JavaScript PyTorch
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
生成对抗网络(GAN)的训练效果高度依赖于损失函数的选择。本文介绍了经典GAN损失函数理论,并用PyTorch实现多种变体,包括原始GAN、LS-GAN、WGAN及WGAN-GP等。通过分析其原理与优劣,如LS-GAN提升训练稳定性、WGAN-GP改善图像质量,展示了不同场景下损失函数的设计思路。代码实现覆盖生成器与判别器的核心逻辑,为实际应用提供了重要参考。未来可探索组合优化与自适应设计以提升性能。
470 7
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
|
4月前
|
机器学习/深度学习 存储 PyTorch
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统
本文通过使用 Kaggle 数据集训练情感分析模型的实例,详细演示了如何将 PyTorch 与 MLFlow 进行深度集成,实现完整的实验跟踪、模型记录和结果可复现性管理。文章将系统性地介绍训练代码的核心组件,展示指标和工件的记录方法,并提供 MLFlow UI 的详细界面截图。
154 2
PyTorch + MLFlow 实战:从零构建可追踪的深度学习模型训练系统

热门文章

最新文章

推荐镜像

更多