Meta 内部都在用的 FX 工具大起底:利用 Graph Transformation 优化 PyTorch 模型

本文涉及的产品
交互式建模 PAI-DSW,5000CU*H 3个月
简介: Meta 内部都在用的 FX 工具大起底:利用 Graph Transformation 优化 PyTorch 模型

image.png

PyTorch 中的 graph mode 在性能方面表示更为出色,本文介绍 Torch.FX 这个强大工具,可以捕捉和优化 PyTorch 程序 graph。

一、简介

PyTorch 支持两种执行模式:eager mode 和 graph mode。

eager mode 中,模型中的运算符在读取时会立即执行,它易于使用,对机器学习从业者更友好,因此被设置为默认的执行模式。

graph mode 中,运算符先被合成一个 graph,然后作为一个整体进行编译和执行,它的性能更高,因此在实际生产中大量使用。

具体来说,graph mode 支持算子融合,两个算子通过合并,可以降低或本地化内存读取以及内核启动总开销。

融合可以是横向 (horizontal) 的:采取应用于多个 operand 的单一操作(如 BatchNorm),并将这些 operand 合并到一个数组中。

融合也可以是纵向 (vertical) 的:将一个内核与另一个内核合并,后者需要使用第一个内核的输出(如 ReLU 后接卷积)。

Torch.FX(缩写为 FX)是一个公开可用的工具包,作为 PyTorch 软件包的一部分,支持 graph mode 的执行。它可以:

  1. 从 PyTorch 程序中获取 graph
  2. 允许开发者在获取的 graph 上编写 transformation**

Meta 内部先前已经在用 FX 来优化生产模型 (production model) 的训练吞吐量 (training throughput)。本文将通过介绍 Meta 开发的基于 FX 的优化,来展示利用图结构转换 (graph transformation) 优化 PyTorch 部署模型性能的方法

二、背景

embedding table 广泛存在于推荐系统中,本节将介绍 FX 和 embedding table 的背景知识。

2.1. FX

图 1 是一个简单示例,演示了如何用 FX 转换 PyTorch 程序,它包含三个步骤:

  • 从程序中获取 graph
  • 修改 graph(在本例中,我们用 GELU 代替 RELU)
  • 从修改后的 graph 中生成一个新程序

image.png

图1:在 PyTorch 模块中用 GELU 取代 RELU 的 FX

FX API 为检查和转换 PyTorch 程序 graph 还提供了许多其他功能。

2.2. embedding table

image.png

图2:批尺寸=1 的稀疏特征 embedding table 示意图

在推荐系统中,稀疏特征(例如,User ID,Story ID)由 embedding table 表示。

embedding table E 是一个 HxD 矩阵,其中 H 是哈希大小,D 是嵌入向量维度。E 的每一行都是一个浮点数向量。

feature hashing 的作用是将一个稀疏特征映射到 E的索引列表中,例如 [S1,S2,...,Sk],其中 0≤Si<H。它的输出值计算为 f(E[S1], E[S2],...,E[Sk]),其中 E[Si] 是 Si 行的向量,f 是池化函数,通常是 sum,average,max 三个函数之一。

为了充分利用 GPU,稀疏特征通常为批处理。批处理中的每个实体都有自己的索引列表。如果一个批次有 B 个实体,可以简单理解为一个表征有 B 个索引列表。

更为严谨的表示方法是将 B 个索引列表合并成一个索引列表,并添加一个索引长度的列表(该批中的每个实体都有一个长度 length)。

例如,如果一批包含 3 个实体,其索引列表如下:

  • Entity 1: indices = [10, 20]
  • Entity 2: indices = [5, 9, 77, 81]
  • Entity 3: indices = [15, 20, 45]

则完整批尺寸的 indice 和 length 将是:

  • Indices = [10, 20, 5, 9, 77, 81, 15, 20, 45]
  • Lengths = [2, 4, 3]

而整个 batch 的 embedding table 查询,输出为是一个 BxD 矩阵。

三、3 种 FX Transformation

PyTorch 更新了 3 个 FX transformation,以加速对 embedding table 的访问,本节将逐一介绍。

下文 3.1 关于将多个小输入张量结合成一个大张量的转换;3.2 关于将多个并行计算链融合成一个计算链的转换;3.3 关于将通信与计算重叠的转换。

3.1 结合输入稀疏特征

batch 中的每个输入稀疏特征,都可以表示为两个列表:一个索引列表和一个 B length 列表,其中 B 表示批尺寸。

在 PyTorch 中,这两个列表都可以以张量的形式存在。当 PyTorch 模型在 GPU 上运行时,embedding table 通常存储在 GPU 内存中(它更接近 GPU,读写带宽比 CPU 内存更高)。

需要使用输入稀疏特征时,两个张量都要先从 CPU 复制到 GPU。然而每个主机到设备的内存复制都需要启动内核,这对于实际的数据传输来说,会更加耗费时间。

如果一个模型使用了多个输入稀疏特征,这种复制可能成为性能瓶颈(例如,1000 个输入稀疏特征将需要从主机到设备复制 2000 个张量)。

一个减少主机到设备 memcpy 数量的优化方法,就是在多个输入稀疏特征发送到设备之前,先将其进行组合。

例如,给定以下三个输入特征:

  • Feature_A: indices = [106, 211, 7], lengths = [2, 1]
  • Feature_B: indices = [52, 498, 616, 870, 1013], lengths = [3, 2]
  • Feature_C: indices = [2011, 19, 351, 790], lengths = [1, 3]

组合后的形式为:

Features_A_B_C: indices = [106, 211, 7, 52, 498, 616, 870, 1013, 2011, 19, 351, 790], lengths = [2, 1, 3, 2, 1, 3]

所以不需要从主机到设备复制 3x2=6 个张量,只需要复制 2 个张量。

图 3(b) 描述了这种优化的实现,它包含两个组件:

  • CPU 端:输入 pipeline 被修改为将所有稀疏特征的 indices 组合成一个张量,所有 length 组合成另一个张量。然后将这两个张量复制到 GPU 上。
  • GPU 端:使用 FX,在模型 graph 中插入一个Permute_and_Split 算子,从合并的张量中恢复单个特征 indices 和 length 张量,并将其发送至下游的相应节点。

image.png

优化前:两个张量都要从 CPU 复制到 GPU

image.png

优化后:将输入稀疏特征进行组合

3.2 从访问 embedding table 开始的计算链横向融合

在一个生产模型中,每个 GPU 上有 10 个 embedding table 很常见。出于性能方面的考虑,对这些 table 的查询被分到一组,这样它们的输出就被串联在一个大张量中(见图 4(a)中的红色部分)。

为了对单个特征输出进行计算,使用 Split 算子将大张量分成 N 个小张量(其中 N 为特征的数量),然后将所需的计算应用于每个张量。

如图 4(a) 所示,应用于每个特征输出 O 的计算是Tanh(LayerNorm(O))。所有的计算结果都被串联成一个大的张量,然后传递给下游的算子(图 4(a) 中的 Op1)。

这里主要的 runtime cost 是 GPU 内核启动的开销。例如,图 4(a) 中的 GPU 内核的启动次数为 2*N+3(图中的每个椭圆都表示一个 GPU 内核)。这会影响性能,因为 LayerNorm 和 Tanh 在 GPU 上的执行时间,与它们的内核启动时间相比很短。

此外,Split 算子可能会创建一个额外的嵌入向量输出张量的副本,消耗额外的 GPU 内存。

用 FX 来实现一种叫做横向融合 (horizontal fusion) 的优化,可以大大减少 GPU 内核的启动次数(在这个例子中,优化后的 GPU 内核启动次数为 5,见图 4(b))。

使用 Add_middle_dim 算子代替显式 Split,将 shape 为 (B, NxD) 的 2D 嵌入张量重塑为 shape 为 (B, N, D) 的 3D 张量。接下来将一个单一的 LayerNorm 应用到它的最后一维。对 LayerNorm 的结果应用一个 Tanh。最后,用 Remove_middle_dim 算子将 Tanh 的结果恢复成 2D 张量。

由于 Add_middle_dim 和 Remove_middle_dim 只是重塑张量,并没有创建额外的副本,所以也可以减少 GPU 内存的消耗。

image.png

优化前:所有输出被串联到一个大张量中

image.png

进行横向融合优化后

3.3 计算与通信间的重叠 (overlap)

面向投产的推荐模型的训练,通常是在分布式 GPU 系统上完成的。由于每个 GPU 的设备内存容量不足以容纳模型中的所有 embedding table,因此需要将其分布在多个 GPU 上。

在训练步骤中,GPU 需要从其他 GPU 上的 embedding table 中读取/写入特征值。这被称为 all-to-all 通信,可能是影响性能的重要原因。

通过 FX 实现一个 transformation,可以将计算与 all-to-all 通信重叠。图 5(a) 显示了一个具备嵌入向量 table 访问 (EmbeddingAllToAll) 及其他算子的模型 graph 实例。如图 5(b) 所示,在没有任何优化的情况下,它们会在一个 GPU 流上顺序执行。

使用FX将 EmbeddingAllToAll 分成 EmbeddingAllToAll_Request和EmbeddingAllToAll_Wait,并在它们之间安排独立的算子。

image.png

图5:计算与通信的重叠

3.4 总结

image.png

表1:本节讨论的优化及解决的相应性能瓶颈

为了发现哪些模型会从这些 transformation 中受益,开发人员对 MAIProf 收集的运行在 Meta 数据中心的模型的性能数据进行分析。得出与 eager mode 相比,这些 transformation 在一组生产模型上实现了 2-3 倍的速度提升。

四、结语

从性能角度考量,PyTorch 中的 graph mode 比生产环境中使用的 eager mode 更受欢迎。FX 是一个强大的工具,可以捕捉和优化 PyTorch 程序 graph。本文展示了三种 FX transformation,用于优化 Meta 内部的生产推荐模型。

最后希望更多 PyTorch 开发者可以使用 graph transformation 来提升模型的性能。

—— 完 ——

相关实践学习
基于阿里云DeepGPU实例,用AI画唯美国风少女
本实验基于阿里云DeepGPU实例,使用aiacctorch加速stable-diffusion-webui,用AI画唯美国风少女,可提升性能至高至原性能的2.6倍。
相关文章
|
2月前
|
机器学习/深度学习 自然语言处理 PyTorch
【PyTorch实战演练】基于AlexNet的预训练模型介绍
【PyTorch实战演练】基于AlexNet的预训练模型介绍
82 0
|
1月前
|
机器学习/深度学习 关系型数据库 MySQL
大模型中常用的注意力机制GQA详解以及Pytorch代码实现
GQA是一种结合MQA和MHA优点的注意力机制,旨在保持MQA的速度并提供MHA的精度。它将查询头分成组,每组共享键和值。通过Pytorch和einops库,可以简洁实现这一概念。GQA在保持高效性的同时接近MHA的性能,是高负载系统优化的有力工具。相关论文和非官方Pytorch实现可进一步探究。
102 4
|
13天前
|
PyTorch 算法框架/工具 Python
【pytorch框架】对模型知识的基本了解
【pytorch框架】对模型知识的基本了解
|
23天前
|
机器学习/深度学习 算法 PyTorch
PyTorch模型优化与调优:正则化、批归一化等技巧
【4月更文挑战第18天】本文探讨了PyTorch中提升模型性能的优化技巧,包括正则化(L1/L2正则化、Dropout)、批归一化、学习率调整策略和模型架构优化。正则化防止过拟合,Dropout提高泛化能力;批归一化加速训练并提升性能;学习率调整策略动态优化训练效果;模型架构优化涉及网络结构和参数的调整。这些方法有助于实现更高效的深度学习模型。
|
23天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch与迁移学习:利用预训练模型提升性能
【4月更文挑战第18天】PyTorch支持迁移学习,助力提升深度学习性能。预训练模型(如ResNet、VGG)在大规模数据集(如ImageNet)训练后,可在新任务中加速训练,提高准确率。通过选择模型、加载预训练权重、修改结构和微调,可适应不同任务需求。迁移学习节省资源,但也需考虑源任务与目标任务的相似度及超参数选择。实践案例显示,预训练模型能有效提升小数据集上的图像分类任务性能。未来,迁移学习将继续在深度学习领域发挥重要作用。
|
2月前
|
PyTorch 算法框架/工具 Python
Pytorch构建网络模型时super(__class__, self).__init__()的作用
Pytorch构建网络模型时super(__class__, self).__init__()的作用
12 0
|
3月前
|
机器学习/深度学习 编解码 PyTorch
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
|
2月前
|
机器学习/深度学习 算法 PyTorch
【PyTorch实战演练】深入剖析MTCNN(多任务级联卷积神经网络)并使用30行代码实现人脸识别
【PyTorch实战演练】深入剖析MTCNN(多任务级联卷积神经网络)并使用30行代码实现人脸识别
90 2
|
3月前
|
机器学习/深度学习 算法 PyTorch
pytorch实现手写数字识别 | MNIST数据集(全连接神经网络)
pytorch实现手写数字识别 | MNIST数据集(全连接神经网络)
|
5月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch深度学习中卷积神经网络(CNN)的讲解及图像处理实战(超详细 附源码)
PyTorch深度学习中卷积神经网络(CNN)的讲解及图像处理实战(超详细 附源码)
130 0

热门文章

最新文章