FlashAttention2原理解析以及面向AIGC的加速实践

简介: FlashAttention2原理解析以及面向AIGC的加速实践



FlashAttention-2提出后,便得到了大量关注。本文将具体讲述FlashAttention-2的前世今生,包括FlashAttention1&2的原理解析、加速效果比较以及面向AIGC的加速实践,在这里将相关内容与大家分享~


引言

将 Transformers 扩展到更长的序列长度一直是过去几年的一个热点问题,这将有助于提高语言建模和高分辨率图像理解的能力,也有利于音频和视频生成方面的新应用场景研发。Attention层是扩展到更长序列的主要瓶颈,因为它的运行时间和内存占用是序列长度的二次方。使用近似计算的Attention方法,可以通过减少FLOP计算次数、甚至于牺牲模型质量来降低计算复杂性,但通常无法实现大比例的加速。

由斯坦福大学提出的FlashAttention方法,让使用更长sequence计算Attention成为可能,并且通过线性级别的增长来节省内存以及加速计算。因为FlashAttention没有进行近似计算,所以也没有精度损失。然而,FlashAttention的实际速度仍然和理论上的运算速度差距较大,仅达到理论最大 FLOPs/s 的 25-40%。效率低下的原因主要是不同线程块和warp之间的工作分区不理想,导致低占用率或不必要的共享内存读/写。为此,2023年7月,论文作者进一步提出了FlashAttention-2,实现了Attention计算速度的大幅度提升。


FlashAttention

 主要内容


FlashAttention主要关注IO-aware,进一步优化GPU显存的读写效率。这是一种 IO 感知的精确Attention算法,它使用tiling(这里可以理解为分块)来减少 GPU 高带宽内存 (HBM) 和 GPU 片上 SRAM 之间的内存读/写次数。这里的HBM可以理解为显存,SRAM可以理解为cache。通过测试IO复杂性,相比标准 Attention,FlashAttention需要更少的 HBM 访问,并且对于不同的SRAM 大小来说都是有效的。除此以外,FlashAttention还可以扩展到block-sparse attention,产生比任何现有近似注意力方法更快的近似注意力算法。FlashAttention与 MLPerf 1.1 训练速度相比,对于BERT-large(序列长度 512)实现端到端wall-clock加速15%,对于GPT-2(序列长度 1K)加速 3 倍。FlashAttention 和block-sparse FlashAttention 可在 Transformers 中实现更长的上下文,从而产生更高质量的模型,GPT-2 上的困惑度提升0.7,长文档分类的test结果提高 6.4 个点。

 主要操作



背景知识:

上图的左图,表示存储结构,可以简单理解为:SRAM表示缓存,HBM表示显存,DRAM表示内存。


  • tiling


在不访问整个输入的情况下优化attention计算,并减少相关计算量。重构attention计算,将输入分割成块,并对分块进行多次传递,从而逐步执行attention计算(该步骤称为tiling)。



如上图所示,FlashAttention 使用tiling来防止在相对较慢的 GPU显存上实现大型 𝑁 × 𝑁 注意力矩阵(虚线框)计算。在外部循环(红色箭头)中,FlashAttention 循环遍历 K 和 V 矩阵块,并将它们加载到快速片上 SRAM。在每个块中,FlashAttention 循环遍历 Q 矩阵块(蓝色箭头),将它们加载到 SRAM,并将注意力计算的输出写回 HBM。



将输入Q、K、V矩阵分成很多块,将它们从较慢的HBM加载到较快的SRAM,然后在SRAM计算关于这些块的注意力输出。对每个块的计算结果缩放之后进行add操作,则得到正确的结果,具体伪代码如图:


  • recomputing


FlashAttention不专门存储用于后向计算的大型中间计算结果。在SRAM中存储前向计算中的 softmax 归一化因子,以便在后向传递计算梯度的时候快速得到中间结果,这比从 HBM 读取中间计算结果的标准方法更快。FlashAttention不存储前向计算中𝑂(𝑁2)复杂度的中间值,但是后向传递通常需要矩阵 S, P ∈ R𝑁 ×𝑁 来计算相对于 Q、K、V 的梯度。通过存储输出 O 和 softmax 归一化统计量 (𝑚, ℓ),则可以在 SRAM 中的 Q、K、V 块的后向计算中快速重新计算注意力矩阵 S 和 P。这可以看作是checkpoint的一种形式。如下图所示,FlashAttention由于tiling分块操作和recomputing操作,增加了一些计算次数。但是还是通过使用SRAM减少了显存占用,通过减少hbm访问次数加快了attention计算。

 Block-Sparse FlashAttention


论文还提出了Block-sparse FlashAttention,其IO复杂度比FlashAttention小,与稀疏度成正比。非0矩阵越少(即0矩阵越多),稀疏化来压缩数据的空间就越大,block-sparse加速就越明显。上图中,IO复杂度与稀疏性成正比,随着稀疏性的增加(非0矩阵增加),Block-sparse FlashAttention的运行时间成比例地提高。

 小结


总的来说,FlashAttention有如下优点:

  1. hbm访问次数降低,所以计算更快
  2. 在sram中计算attention,并对于后向计算提前保留中间结果,所以显存占用更少
  3. 可以使用更长的sequence,使得模型训练效果更好
  4. 对于attention计算,加速明显。如果加上稀疏化处理,速度会更快。


FlashAttention-2


 主要内容


FlashAttention 的整体速度仍然和单独进行矩阵乘法 (GEMM) 的运算速度差距较大,仅达到理论最大 FLOPs/s 的 25-40%。作者观察到效率低下的原因是不同线程块和warp之间的工作分区不理想,导致低占用率或不必要的共享内存读/写。最新提出 FlashAttention-2,通过更好的工作分区来解决这些问题,主要包含的操作:1.调整算法以减少非矩阵乘运算的计算次数。2.跨不同线程块进行并行化注意力计算。3.在每个线程块内, 在 warp 之间优化工作分配以减少共享内存的通信。与FlashAttention 相比,FlashAttention-2速度提高了约 2 倍,达到 A100 上理论最大 FLOPs/s 的 50-73%,接近 GEMM 操作的效率。根据经验验证,当使用端到端来训练 GPT 式模型时,FlashAttention-2 的训练速度高达每 A100 GPU 225 TFLOPs/s(模型 FLOPs 利用率为 72%)。不同设置(有或没有causal mask、不同头部尺寸)的测试表明,FlashAttention-2 比 FlashAttention 实现了约 2 倍的加速,在前向传递中达到理论最大吞吐量的 73%,在后向传递中达到理论最大吞吐量的 63%。

 主要操作


  • 减少非矩阵运算


调整算法以减少非 matmul(矩阵乘法) FLOP 的数量,同时不改变输出。虽然非 matmul FLOP 只占总 FLOP 的一小部分,但它们由于 GPU 具有专门的矩阵乘法单元,非矩阵乘法的运算需要更长的时间来执行,矩阵乘法吞吐量可以比非矩阵乘法吞吐量高出16倍。因此,减少非 matmul FLOP 并尽可能多的进行 matmul FLOP 非常重要。

背景知识: 吞吐量是指单位时间内完成的任务数量或数据处理量。在这个上下文中,吞吐量指的是执行矩阵乘法操作时的性能表现,以及执行其他非矩阵乘法操作时的性能表现。这句话的意思是,执行矩阵乘法操作时,系统能够以每单位时间处理更多的任务或数据,其数量可以高达非矩阵乘法操作时的16倍。这表明矩阵乘法操作在性能上比其他操作更加高效。


Forward pass:优化qkv的softmax计算中非矩阵运算
Backward pass:FlashAttention-2的后向传递与FlashAttention几乎相同,主要区别在于需要进行梯度计算与更新。这里做了一个小调整,只使用求和结果𝐿,而不是 softmax 中的行式最大值和行式指数和。


  • 增加并行比例


除了batchsize维度和head数目维度,还在序列长度维度上对前向传播和反向传播进行并行化处理,提高并行性。在序列较长的情况下,提高GPU资源的占用率。FlashAttention对于batchsize和head数目进行并行化处理,FlashAttention2基于序列长度进行并行化。当批量大小和head数量较小时,序列长度上增加的并行性有助于提高占用率(正在使用的 GPU 资源的比例),从而在这种情况下实现加速。


Forward pass:对批量维度和头数维度进行并行化,如 FlashAttention 中所做的那样。对于外循环(在序列长度上),将它们调度到不需要彼此通信的不同线程块上,每个工作线程负责关注矩阵的一行block块。外循环每次处理一行block,内循环每次处理这一行中的一列block,这和FlashAttention处理方式是不同的。



Backward pass:不同列块之间唯一共享的计算是算法 2 中更新的dQ,其中我们需要将 dQ从 HBM 加载到 SRAM,然后在片上通过 dQ更新,并写回 HBM。我们使用原子添加在不同线程块之间进行通信以更新 dQ。我们也在序列长度维度上进行并行化,并为后向传递的每一列block块安排 1 个工作线程(和前向传递是反过来的)。


  • 在warp上优化工作划分


在一个注意力计算的block内,在一个thread block的不同warp之间优化工作划分,以减少通信和共享内存的读/写。



在每个线程块内,我们也必须决定如何在不同的 warp 之间划分工作。我们通常每个线程块使用 4 或 8 个 warp,分区如上图所示。



Forward pass:对于每个块,FlashAttention 将 K 和 V 分割到 4 个 warp 上,同时保持 Q 可被所有 warp 访问。每个warp相乘得到 QK⊤ 的slice,然后它们需要与 V 的slice相乘并进行通信以将结果相加。这称为“split-K”方案。然而,这是低效的,因为所有 warp 都需要将其中间结果写入共享内存,进行同步,然后将中间结果相加这些共享内存读/写会减慢 FlashAttention 中的前向传播速度。在 FlashAttention-2 中,我们将 Q 分成 4 个经线,同时保持所有经线均可访问 K 和 V。在每个扭曲执行矩阵乘法以获得 QK⊤ 切片后,它们只需与共享的 V 切片相乘即可获得相应的输出切片。warp 之间不需要通信。共享内存读/写的减少可以提高速度


背景知识: warp:由多个thread组成,是编程层面的概念。 flash1:k和v被分为4个不同的warp,q和k计算、再和v计算,每一次计算的中间结果都要写入共享内存,并在之后被读取。这样就增加了共享内存的读写次数、拖慢了速度。 flash2:将q分为4个不同的warp,然后计算qk、计算v。但是这里k和v不需要通信,所以计算v的时候,不需要新的内存读写。这样就减少了读写次数、加快了程序。


Backward pass:对于后向传递,我们选择对warp进行分区以避免“split-K”方案,从而减少共享内存的读/写次数,并再次提高速度。由于所有不同输入和梯度 Q、K、V、O、dO、dQ、dK、dV 之间的依赖性更加复杂,它需要一些同步操作。

 小结


FlashAttention-2可以加速attention计算。测量FlashAttention-2 在不同序列长度上的运行时间,并与 PyTorch、FlashAttention 和 Triton 中的 FlashAttention 中的标准实现进行比较。FlashAttention-2 比 FlashAttention 快 1.7-3.0 倍,比 Triton 中的 FlashAttention 快 1.3-2.5 倍,比标准注意力实现快 3-10 倍。FlashAttention-2可以加速端到端训练。当使用端到端在 2k 或 8k 序列长度上训练大小为 1.3B 和 2.7B 的 GPT 型模型时,FlashAttention-2 与 FlashAttention 相比可实现高达 1.3 倍的加速,与基线相比可实现 2.8 倍的加速 没有FlashAttention。每个 A100 GPU 的 FlashAttention-2 速度高达 225 TFLOPs/s(模型 FLOPs 利用率为 72%)。FlashAttention-2 比 FlashAttention 快 2 倍,可以用之前训练 8k 上下文模型的时间,来训练具有 16k 更长上下文的模型。使用更长的context训练模型,可以更好理解长篇书籍和报告、高分辨率图像、音频和视频。

FlashAttention-2加速实践


 时间与显存的优化效果


对于qkv计算,比较FlashAttention2与custom pytorch、xformers(FlashAttention1)的时间与显存消耗。如果只考虑QKV计算,flash attention2耗时是xformers(flash attention1)的一半,内存节省也更多一些。



flash attention2耗时是xformers(flash attention1)的一半,内存节省也更多一些
test 0 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000754, peak memory: 113 MB
flash attention time: 0.000103, speedup: 7.29; peak memory: 45 MB, save: 60%
xformers time: 0.000255, speedup: 2.95; peak memory: 63 MB, save: 44%
test 1 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000703, peak memory: 131 MB
flash attention time: 0.000106, speedup: 6.63; peak memory: 57 MB, save: 56%
xformers time: 0.000252, speedup: 2.80; peak memory: 70 MB, save: 46%
test 2 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000721, peak memory: 131 MB
flash attention time: 0.000106, speedup: 6.78; peak memory: 57 MB, save: 56%
xformers time: 0.000263, speedup: 2.74; peak memory: 70 MB, save: 46%
test 3 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000704, peak memory: 131 MB
flash attention time: 0.000105, speedup: 6.71; peak memory: 57 MB, save: 56%
xformers time: 0.000249, speedup: 2.82; peak memory: 70 MB, save: 46%
test 4 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000700, peak memory: 131 MB
flash attention time: 0.000110, speedup: 6.35; peak memory: 57 MB, save: 56%
xformers time: 0.000254, speedup: 2.75; peak memory: 70 MB, save: 46%
test 5 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000766, peak memory: 131 MB
flash attention time: 0.000106, speedup: 7.25; peak memory: 57 MB, save: 56%
xformers time: 0.000252, speedup: 3.04; peak memory: 70 MB, save: 46%
test 6 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000684, peak memory: 131 MB
flash attention time: 0.000101, speedup: 6.77; peak memory: 57 MB, save: 56%
xformers time: 0.000268, speedup: 2.56; peak memory: 70 MB, save: 46%
test 7 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000717, peak memory: 131 MB
flash attention time: 0.000110, speedup: 6.52; peak memory: 57 MB, save: 56%
xformers time: 0.000254, speedup: 2.82; peak memory: 70 MB, save: 46%
test 8 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000700, peak memory: 131 MB
flash attention time: 0.000100, speedup: 6.98; peak memory: 57 MB, save: 56%
xformers time: 0.000253, speedup: 2.77; peak memory: 70 MB, save: 46%
test 8 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000700, peak memory: 131 MB
flash attention time: 0.000100, speedup: 6.98; peak memory: 57 MB, save: 56%
xformers time: 0.000253, speedup: 2.77; peak memory: 70 MB, save: 46%
test 9 >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
shape: (9, 351, 16, 64), dtype: torch.bfloat16, causal: True
custom pytorch time: 0.000721, peak memory: 131 MB
flash attention time: 0.000102, speedup: 7.10; peak memory: 57 MB, save: 56%
xformers time: 0.000251, speedup: 2.87; peak memory: 70 MB, save: 46%

 精度损失比较


计算FlashAttention2对于注意力机制的精度损失,与pytorch的计算精度进行对比。



绝大部分用例都可以通过测试,并且符合要求:



dQ Pytorch mean diff: 0.000698089599609375
dK Pytorch mean diff: 0.0005950927734375
dV Pytorch mean diff: 0.000537872314453125
.Actual dropout fraction: 0.17163611948490143
Output max diff: 0.001953125
Output mean diff: 2.9206275939941406e-05
Pytorch max diff: 0.0029296875
Pytorch mean diff: 8.106231689453125e-05
Attention max diff: 0.000244140625
Attention Pytorch max diff: 0.000732421875
dQ max diff: 0.0025577545166015625
dK max diff: 0.00390625
dV max diff: 0.0078125
dQ mean diff: 3.904104232788086e-05
dK mean diff: 0.0001360177993774414
dV mean diff: 0.0001475811004638672
dQ Pytorch max diff: 0.00390625
dK Pytorch max diff: 0.004150390625
dV Pytorch max diff: 0.0078125
dQ Pytorch mean diff: 8.702278137207031e-05
dK Pytorch mean diff: 0.00025916099548339844
dV Pytorch mean diff: 0.0002474784851074219
.Actual dropout fraction: 0.17163611948490143
Output max diff: 0.015625
Output mean diff: 0.0002346038818359375
Pytorch max diff: 0.015625
Pytorch mean diff: 0.00064849853515625
Attention max diff: 0.001953125
Attention Pytorch max diff: 0.00390625
dQ max diff: 0.01953125
dK max diff: 0.033203125
dV max diff: 0.0625
dQ mean diff: 0.0003108978271484375
dK mean diff: 0.00109100341796875
dV mean diff: 0.0011749267578125
dQ Pytorch max diff: 0.01806640625
dK Pytorch max diff: 0.0390625
dV Pytorch max diff: 0.0625
dQ Pytorch mean diff: 0.00069427490234375
dK Pytorch mean diff: 0.0020751953125
dV Pytorch mean diff: 0.001953125
...


少数用例不符合要求,有一定的精度损失:


FAILED tests/test_flash_attn.py::test_flash_attn_race_condition[0.0-128-128-False-dtype0] - assert False
FAILED tests/test_flash_attn.py::test_flash_attn_race_condition[0.0-128-128-True-dtype0] - assert False
FAILED tests/test_flash_attn.py::test_flash_attn_bwd_transpose[128-128-False-dtype0] - AssertionError: assert 236.75 <= (2 * 0.0009765625)
FAILED tests/test_flash_attn.py::test_flash_attn_bwd_transpose[128-128-False-dtype1] - AssertionError: assert 22144.0 <= (2 * 0.0078125)
FAILED tests/test_flash_attn.py::test_flash_attn_bwd_transpose[128-128-True-dtype0] - AssertionError: assert 2.724609375 <= (2 * 0.001953125)
FAILED tests/test_flash_attn.py::test_flash_attn_bwd_transpose[128-128-True-dtype1] - AssertionError: assert 95.5 <= (2 * 0.015625)


FlashAttention2与参考方法的输出和梯度相比,误差很小并在可控范围内。对于不同的head dimensions, input dtype, sequence length, causal / non-causal,FlashAttention2的最大数值误差最多是 Pytorch的baseline中的数值误差的两倍。对于前向计算和后向计算。前向计算,是确定性的,每次测试结果可以复现。后向计算,非确定性(没有bit级别的确定性),每次结果可能有略微的不一样(比如输入不变,seed改变)。如果只做推理,则只涉及前向计算,所以计算是确定性的。
FlashAttention-2加速AIGC


 环境信息


NVIDIA A10, CUDA Version: 11.4, webui-1.5.1, eas推理平台


 加速效果


xformers(flash1):


文生图(512*512)(batchsize=1) 文生图(512*512)(batchsize=4)
unet耗时(s)

1

1

1

1

4

4

4

4

unet耗时(it/s)

(step = 20)

11.11it/s

11.27it/s

11.27it/s

11.27it/s

4.33it/s

4.33it/s

4.33it/s

4.33it/s


xformers(flash2):


文生图(512*512) 文生图(512*512)(batchsize=4)
unet耗时(s)

1

1

1

1

4

4

4

4

unet耗时(it/s)

(step = 20)

11.13it/s

11.75it/s

11.46it/s

11.92it/s

4.69it/s

4.69it/s

4.69it/s

4.68it/s


相对于xformers(flash1),xformers(flash2)提速:


unet过程提速
文生图加速(一次生成1图) (11.57-11.23)/11.23=3%
文生图加速(一次生成4图) (4.69-4.33)/4.33=8.3%


 精度比较


xformers(flash1)

文生图(512*512)_ouput1 文生图(512*512)_ouput2


xformers(flash2)

文生图(512*512)_ouput1 文生图(512*512)_ouput2


使用不同的加速方法,AIGC生成图像,均符合预期,无精度损失。


注:这里未固定seed,所以图像会有变化,但是生成效果符合预期。


 AIGC加速分析


使用flash_attention2,对sd加速,相比flash_attention1,加速比例并不高,或者说无法达到论文中那么高的加速比例。

  • SD模型自身特点


flash_attention2主要是针对qkv计算进行加速,sd的推理过程中还有很多别的计算。推理过程中,进行采样(去噪),具有大量的计算,qkv计算只是推理计算的一部分。对于大图,计算量也更大,qkv的计算比例也更大,所以可以得到更多的加速效果。



SD模型的网络结构:


  • SD社区代码特点

目前方法对于sd的提速,主要针对stable diffusion的神经网络本身。但是webui还有别的耗时:1.webui生图以外,还需要进行很多后处理,比如序列化反序列化、图片后处理、转换格式、传递图片等。2.webui是一个社区项目,兼容的功能非常多、而且杂,里面各种判断逻辑。这些操作拖慢了速度,端到端速度表现一般。

  • 显卡性能特殊性


加速比例,对于不同的GPU效果不一样。3090这张卡比较特殊,计算性能好,但是显存的读取速度很差。所以在batchsize小的时候,性能卡在显存读取速度上,加速比高不起来。如果想看到更高的加速比,可以试试加大batchsize,这样diffusion占用的时间变多,网络和反序列化消耗时间的占比变小。如果用A系列卡效果会好一些。A10的性能比3090差,和他的显存读取速度匹配,加速方法对算法的优化比较符合卡的特点。
FlashAttention-2与fastunet对于AIGC联合加速

为了进一步优化aigc生图效率,使用webui更快速地进行加速,我们针对diffusion model特点,通过fastunet与FlashAttention-2结合的方式进行加速,并取得了相对于flash1已有加速效果的大于40%的提速。

 实验环境


NVIDIA A10, CUDA Version: 11.4, webui-1.5.1, eas推理平台

 加速效果


xformers(flash2)+fastunet


文生图(512*512) 文生图(512*512)(batchsize=4)
unet耗时(s) 1111

3

3

3

3

unet耗时(it/s)

(step = 20)

17.06it/s

18.22it/s

17.36it/s

16.43it/s

6.26it/s

6.27it/s

6.25it/s

6.25it/s


相对于xformers(flash1),xformers(flash2)+fastunet提速:


unet过程提速
文生图加速(一次生成1图) (17.26-11.23)/11.23=54%
文生图加速(一次生成4图) (6.26-4.33)/4.33=45%


加速效果:flash attention2 + fastunet > flash attention2 > flash attention1
 精度比较

使用xformers(flash2)+fastunet加速方法,AIGC生成图像,结果符合预期,无精度损失。


文生图(512*512)_ouput1 文生图(512*512)_ouput2


 AIGC加速分析


生图过程主要有两部分耗时:controlnet与unet

旧方法:xformers 0.0.20,使用flash attention1加速sd(unet+controlnet)

新方法:1.当前的fastunet只加速unet里的attention(换为flash attention2)。2.xformers0.0.21加速包括controlnet在内的所有attention(换为flash attention2)。3.fastunett还对其他算子也做了一些fuse操作,也起到了加速效果。



fastunet和xformers0.0.21加速的底层逻辑,都是使用flash attention2优化attention。fastunet和xformers0.0.21叠加使用,可以最大程度起到加速效果。新的加速方法主要针对attention计算进行优化,所以在unet及其attention部分会有更高比例的加速。


总结与讨论


近年来,让 Transformers 能够处理更长的序列长度一直备受关注。这一发展有助于提升语言建模和高分辨率图像理解的能力,并为音频和视频生成等新的应用场景带来了机遇。FlashAttention方法使得使用更长的序列计算注意力成为可能,并通过线性级别的增长来节省内存并加速计算。这一方法为处理长序列的Transformer模型提供了一种有效的解决方案。最新提出的FlashAttention-2,也进一步实现了attention计算速度的大幅度提升。



当我们一直在关注GPU显存大小以及计算能力的时候,FlashAttention关注了GPU显存以外的SRAM,从而优化attention计算。也为我们解决问题提供了思考,即在主流关注的技术点以外,还有一些被忽视的但依旧可以解决问题的思路。面对实际效果与理论效果的差距,FlashAttention-2则进一步找到gap原因,通过关注矩阵运算、序列并行、工作分区等问题,优化计算效果。这也提醒我们,对于性能问题的解决,从软硬件结合的角度出发,才能更充分的解决问题。



在AIGC领域的生图任务中,使用diffusion model进行相关计算,需要大量时间完成生图过程。所以,通过FlashAttention-2等多种加速方法进一步提升AIGC的生图效率,具有深刻意义。我们团队致力于家装行业AIGC进行相关研发,以提高家装AI模型的效果。我们希望与对此方向感兴趣的同学一起探讨和交流。


团队介绍


我们是淘天集团-场景智能技术团队,一支专注于通过AI和3D技术驱动商业创新的技术团队, 依托大淘宝丰富的业务形态和海量的用户、数据, 致力于为消费者提供创新的场景化导购体验, 为商家提供高效的场景化内容创作工具, 为淘宝打造围绕家的场景的第一消费入口。我们不断探索并实践新的技术, 通过持续的技术创新和突破,创新用户导购体验, 提升商家内容生产力, 让用户享受更好的消费体验, 让商家更高效、低成本地经营。

相关文章
|
7月前
|
安全 算法 网络协议
解析:HTTPS通过SSL/TLS证书加密的原理与逻辑
HTTPS通过SSL/TLS证书加密,结合对称与非对称加密及数字证书验证实现安全通信。首先,服务器发送含公钥的数字证书,客户端验证其合法性后生成随机数并用公钥加密发送给服务器,双方据此生成相同的对称密钥。后续通信使用对称加密确保高效性和安全性。同时,数字证书验证服务器身份,防止中间人攻击;哈希算法和数字签名确保数据完整性,防止篡改。整个流程保障了身份认证、数据加密和完整性保护。
|
6月前
|
人工智能 API 语音技术
HarmonyOS Next~鸿蒙AI功能开发:Core Speech Kit与Core Vision Kit的技术解析与实践
本文深入解析鸿蒙操作系统(HarmonyOS)中的Core Speech Kit与Core Vision Kit,探讨其在AI功能开发中的核心能力与实践方法。Core Speech Kit聚焦语音交互,提供语音识别、合成等功能,支持多场景应用;Core Vision Kit专注视觉处理,涵盖人脸检测、OCR等技术。文章还分析了两者的协同应用及生态发展趋势,展望未来AI技术与鸿蒙系统结合带来的智能交互新阶段。
351 31
|
6月前
|
机器学习/深度学习 数据可视化 PyTorch
深入解析图神经网络注意力机制:数学原理与可视化实现
本文深入解析了图神经网络(GNNs)中自注意力机制的内部运作原理,通过可视化和数学推导揭示其工作机制。文章采用“位置-转移图”概念框架,并使用NumPy实现代码示例,逐步拆解自注意力层的计算过程。文中详细展示了从节点特征矩阵、邻接矩阵到生成注意力权重的具体步骤,并通过四个类(GAL1至GAL4)模拟了整个计算流程。最终,结合实际PyTorch Geometric库中的代码,对比分析了核心逻辑,为理解GNN自注意力机制提供了清晰的学习路径。
456 7
深入解析图神经网络注意力机制:数学原理与可视化实现
|
6月前
|
机器学习/深度学习 缓存 自然语言处理
深入解析Tiktokenizer:大语言模型中核心分词技术的原理与架构
Tiktokenizer 是一款现代分词工具,旨在高效、智能地将文本转换为机器可处理的离散单元(token)。它不仅超越了传统的空格分割和正则表达式匹配方法,还结合了上下文感知能力,适应复杂语言结构。Tiktokenizer 的核心特性包括自适应 token 分割、高效编码能力和出色的可扩展性,使其适用于从聊天机器人到大规模文本分析等多种应用场景。通过模块化设计,Tiktokenizer 确保了代码的可重用性和维护性,并在分词精度、处理效率和灵活性方面表现出色。此外,它支持多语言处理、表情符号识别和领域特定文本处理,能够应对各种复杂的文本输入需求。
762 6
深入解析Tiktokenizer:大语言模型中核心分词技术的原理与架构
|
6月前
|
机器学习/深度学习 人工智能 自然语言处理
AIGC技术发展与应用实践(一文读懂AIGC)
AIGC(人工智能生成内容)是利用AI技术生成文本、图像、音频、视频等内容的重要领域。其发展历程包括初期探索、应用拓展和深度融合三大阶段,核心技术涵盖数据收集、模型训练、内容生成、质量评估及应用部署。AIGC在内容创作、教育、医疗、游戏、商业等领域广泛应用,未来将向更大规模、多模态融合和个性化方向发展。但同时也面临伦理法律和技术瓶颈等挑战,需在推动技术进步的同时加强规范与监管,以实现健康可持续发展。
|
6月前
|
传感器 人工智能 监控
反向寻车系统怎么做?基本原理与系统组成解析
本文通过反向寻车系统的核心组成部分与技术分析,阐述反向寻车系统的工作原理,适用于适用于商场停车场、医院停车场及火车站停车场等。如需获取智慧停车场反向寻车技术方案前往文章最下方获取,如有项目合作及技术交流欢迎私信作者。
403 2
|
6月前
|
缓存 边缘计算 安全
阿里云CDN:全球加速网络的实践创新与价值解析
在数字化浪潮下,用户体验成为企业竞争力的核心。阿里云CDN凭借技术创新与全球化布局,提供高效稳定的加速解决方案。其三层优化体系(智能调度、缓存策略、安全防护)确保低延迟和高命中率,覆盖2800+全球节点,支持电商、教育、游戏等行业,帮助企业节省带宽成本,提升加载速度和安全性。未来,阿里云CDN将继续引领内容分发的行业标准。
381 7
|
7月前
|
Java 数据库 开发者
详细介绍SpringBoot启动流程及配置类解析原理
通过对 Spring Boot 启动流程及配置类解析原理的深入分析,我们可以看到 Spring Boot 在启动时的灵活性和可扩展性。理解这些机制不仅有助于开发者更好地使用 Spring Boot 进行应用开发,还能够在面对问题时,迅速定位和解决问题。希望本文能为您在 Spring Boot 开发过程中提供有效的指导和帮助。
661 12
|
6月前
|
机器学习/深度学习 人工智能 自然语言处理
DeepSeek 实践应用解析:合力亿捷智能客服迈向 “真智能” 时代
DeepSeek作为人工智能领域的创新翘楚,凭借领先的技术实力,在智能客服领域掀起变革。通过全渠道智能辅助、精准对话管理、多语言交互、智能工单处理、个性化推荐、情绪分析及反馈监控等功能,大幅提升客户服务效率和质量,助力企业实现卓越升级,推动智能化服务发展。
242 1
|
6月前
|
机器学习/深度学习 人工智能 监控
鸿蒙赋能智慧物流:AI类目标签技术深度解析与实践
在数字化浪潮下,物流行业面临变革,传统模式的局限性凸显。AI技术为物流转型升级注入动力。本文聚焦HarmonyOS NEXT API 12及以上版本,探讨如何利用AI类目标签技术提升智慧物流效率、准确性和成本控制。通过高效数据处理、实时监控和动态调整,AI技术显著优于传统方式。鸿蒙系统的分布式软总线技术和隐私保护机制为智慧物流提供了坚实基础。从仓储管理到运输监控再到配送优化,AI类目标签技术助力物流全流程智能化,提高客户满意度并降低成本。开发者可借助深度学习框架和鸿蒙系统特性,开发创新应用,推动物流行业智能化升级。
181 1

推荐镜像

更多
  • DNS