(Beta)使用缩放点积注意力(SDPA)实现高性能 Transformer
原文:
pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
译者:飞龙
注意
点击这里下载完整示例代码
作者: Driss Guessous
摘要
在本教程中,我们想要强调一个新的 torch.nn.functional 函数,它可以帮助实现 Transformer 架构。该函数被命名为 torch.nn.functional.scaled_dot_product_attention。有关该函数的详细描述,请参阅 PyTorch 文档。该函数已经被整合到 torch.nn.MultiheadAttention 和 torch.nn.TransformerEncoderLayer 中。
概述
在高层次上,这个 PyTorch 函数根据论文Attention is all you need中的定义,计算查询、键和值之间的缩放点积注意力(SDPA)。虽然这个函数可以使用现有函数在 PyTorch 中编写,但融合实现可以比朴素实现提供更大的性能优势。
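为了更直观地理解该函数的计算内容,下面给出一个仅作示意的朴素实现(假设性示例:函数名 naive_scaled_dot_product_attention 为演示自拟,且省略了 attn_mask、dropout_p、is_causal 等参数)。它与融合实现在数学上等价,但速度和显存效率要差得多:

import math
import torch

def naive_scaled_dot_product_attention(query, key, value):
    # 按照 "Attention is all you need" 中的定义计算 softmax(Q K^T / sqrt(d_k)) V
    scale = 1.0 / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
    return attn_weight @ value

# 用法示意:其结果应与 F.scaled_dot_product_attention(query, key, value) 在数值误差范围内一致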
融合实现
对于 CUDA 张量输入,该函数将分派到以下实现之一:
- FlashAttention:具有 IO 感知的快速和内存高效的精确注意力
- 内存高效注意力
- 一个在 C++中定义的 PyTorch 实现
注意
本教程需要 PyTorch 2.0.0 或更高版本。
import torch
import torch.nn as nn
import torch.nn.functional as F
device = "cuda" if torch.cuda.is_available() else "cpu"

# Example Usage:
query, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)
F.scaled_dot_product_attention(query, key, value)
tensor([[[-1.3321, -0.3489,  0.3015, -0.3912,  0.9867,  0.3137, -0.0691, -1.2593],
         [-1.0882,  0.2506,  0.6491,  0.1360,  0.5238, -0.2448, -0.0820, -0.6171],
         [-1.0012,  0.3990,  0.6441, -0.0277,  0.5325, -0.2564, -0.0607, -0.6404]],

        [[ 0.6091,  0.0708,  0.6188,  0.3252, -0.1598,  0.4197, -0.2335,  0.0630],
         [ 0.5285,  0.3890, -0.2649,  0.3706, -0.3839,  0.1963, -0.6242,  0.2312],
         [ 0.4048,  0.0762,  0.3777,  0.4689, -0.2978,  0.2754, -0.6429,  0.1037]]],
       device='cuda:0')
显式调度控制
虽然该函数会隐式分派到三种实现之一,但用户也可以通过上下文管理器来显式控制分派。这个上下文管理器允许用户显式禁用某些实现。如果用户想确保函数确实为其特定输入使用了最快的实现,可以借助该上下文管理器逐一测量各实现的性能。
# Lets define a helpful benchmarking function:
import torch.utils.benchmark as benchmark
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6

# Lets define the hyper-parameters of our input
batch_size = 32
max_sequence_len = 1024
num_heads = 32
embed_dimension = 32

dtype = torch.float16

query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)

print(f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

# Lets explore the speed of each of the 3 implementations
from torch.backends.cuda import sdp_kernel, SDPBackend

# Helpful arguments mapper
backend_map = {
    SDPBackend.MATH: {"enable_math": True, "enable_flash": False, "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {"enable_math": False, "enable_flash": True, "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {
        "enable_math": False, "enable_flash": False, "enable_mem_efficient": True}
}

with sdp_kernel(**backend_map[SDPBackend.MATH]):
    print(f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")

with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        print(f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported. See warnings for reasons.")
The default implementation runs in 2263.405 microseconds
The math implementation runs in 19254.524 microseconds
The flash attention implementation runs in 2262.901 microseconds
The memory efficient implementation runs in 4143.146 microseconds
硬件依赖
取决于您在哪台机器上运行上述单元格以及可用的硬件,您的结果可能会有所不同。
- 如果您没有 GPU 而是在 CPU 上运行,那么上下文管理器将不起作用,三次运行应该返回相近的时间。
- 取决于您的显卡支持的计算能力,FlashAttention 或内存高效注意力(memory efficient attention)可能会运行失败(可用下面的示意代码事先检查)。
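如果想事先大致判断当前显卡的计算能力,可以参考下面的示意代码(假设性示例:其中"sm80 及以上才支持 FlashAttention"只是经验性的粗略提示,具体支持情况以运行时产生的警告为准):

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"GPU: {torch.cuda.get_device_name()}, compute capability: {major}.{minor}")
    # 经验性提示:FlashAttention 后端通常需要较新的架构(大致为 sm80 及以上)
    if major < 8:
        print("FlashAttention 后端可能不可用,此时会回退到其他实现")
else:
    print("没有可用的 GPU,上下文管理器不会产生影响")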
因果自注意力
以下是受Andrej Karpathy NanoGPT仓库启发的多头因果自注意力块的示例实现。
class CausalSelfAttention(nn.Module):

    def __init__(self, num_heads: int, embed_dimension: int, bias: bool = False, is_causal: bool = False, dropout: float = 0.0):
        super().__init__()
        assert embed_dimension % num_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)
        # output projection
        self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)
        # regularization
        self.dropout = dropout
        self.resid_dropout = nn.Dropout(dropout)
        self.num_heads = num_heads
        self.embed_dimension = embed_dimension
        # Perform causal masking
        self.is_causal = is_causal

    def forward(self, x):
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        query_projected = self.c_attn(x)

        batch_size = query_projected.size(0)
        embed_dim = query_projected.size(2)
        head_dim = embed_dim // (self.num_heads * 3)

        query, key, value = query_projected.chunk(3, -1)
        query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)

        if self.training:
            dropout = self.dropout
            is_causal = self.is_causal
        else:
            dropout = 0.0
            is_causal = False

        y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)
        y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)

        y = self.resid_dropout(self.c_proj(y))
        return y


num_heads = 8
heads_per_dim = 64
embed_dimension = num_heads * heads_per_dim
dtype = torch.float16
model = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to("cuda").to(dtype).eval()
print(model)
CausalSelfAttention(
  (c_attn): Linear(in_features=512, out_features=1536, bias=False)
  (c_proj): Linear(in_features=512, out_features=512, bias=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
)
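作为一个简单的用法示意(假设性示例:假设 CUDA 可用,批大小和序列长度可以随意选取),可以用随机输入做一次前向传播,验证输出形状与输入一致:

# 用随机输入验证输出形状(示意代码)
x_demo = torch.rand(4, 128, embed_dimension, device="cuda", dtype=dtype)
y_demo = model(x_demo)
print(y_demo.shape)  # 预期输出: torch.Size([4, 128, 512])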
NestedTensor 和密集张量支持
SDPA 支持 NestedTensor 和密集(Dense)张量输入。NestedTensor 用于处理输入为一批可变长度序列的情况,无需将每个序列填充到批量中的最大长度。有关 NestedTensor 的更多信息,请参阅 torch.nested 和 NestedTensors 教程。
import random
def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    if not pad_percentage:
        return (
            torch.randn(
                batch_size,
                max_sequence_len,
                embed_dimension,
                dtype=dtype,
                device=device,
            ),
            None,
        )
    # Random sequence lengths
    seq_len_list = [
        int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        for _ in range(batch_size)
    ]
    # Make random entry in the batch have max sequence length
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len
    return (
        torch.nested.nested_tensor(
            [
                torch.randn(seq_len, embed_dimension, dtype=dtype, device=device)
                for seq_len in seq_len_list
            ]
        ),
        seq_len_list,
    )

random_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)
random_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)

# Currently the fused implementations don't support ``NestedTensor`` for training
model.eval()

with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds")
        print(f"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported. See warnings for reasons.")
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/nested/__init__.py:166: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:177.)
Random NT runs in 560.000 microseconds
Random Dense runs in 938.743 microseconds
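作为补充,下面的示意代码(假设性示例,序列长度为演示取值)展示了如何手动构造一个 NestedTensor,以及如何用 torch.nested.to_padded_tensor 将其转换回带填充的密集张量,以便与只接受密集输入的代码互操作:

# 手动构造一个由不同长度序列组成的 NestedTensor(示意代码)
seqs = [torch.randn(length, embed_dimension, device=device, dtype=dtype) for length in (100, 256, 512)]
nt = torch.nested.nested_tensor(seqs)

# 转换为填充后的密集张量,填充值为 0.0,结果形状为 (3, 512, embed_dimension)
padded = torch.nested.to_padded_tensor(nt, 0.0)
print(padded.shape)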
使用 torch.compile 进行 SDPA
随着 PyTorch 2.0 的发布,引入了一个名为 torch.compile() 的新特性,它可以在 eager 模式的基础上提供显著的性能提升。缩放点积注意力与 torch.compile() 完全兼容。为了演示这一点,让我们使用 torch.compile() 编译 CausalSelfAttention 模块,并观察由此带来的性能提升。
batch_size = 32
max_sequence_len = 256
x = torch.rand(batch_size, max_sequence_len, embed_dimension, device=device, dtype=dtype)
print(
    f"The non compiled module runs in {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds")

compiled_model = torch.compile(model)
# Let's compile it
compiled_model(x)
print(
    f"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds")
The non compiled module runs in 407.788 microseconds
The compiled module runs in 521.239 microseconds
确切的执行时间取决于机器,但就我这里的结果而言:非编译模块运行耗时 166.616 微秒,编译模块运行耗时 166.726 微秒。这不是我们预期的结果。让我们再深入研究一下。PyTorch 自带一个出色的内置分析器,您可以用它来检查代码的性能特征。
from torch.profiler import profile, record_function, ProfilerActivity
activities = [ProfilerActivity.CPU]
if device == 'cuda':
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=False) as prof:
    with record_function(" Non-Compilied Causal Attention"):
        for _ in range(25):
            model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

with profile(activities=activities, record_shapes=False) as prof:
    with record_function("Compiled Causal Attention"):
        for _ in range(25):
            compiled_model(x)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results
#
# .. code-block:: python
#
#    prof.export_chrome_trace("compiled_causal_attention_trace.json").
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                         Non-Compilied Causal Attention        18.51%       2.124ms        75.85%       8.703ms       8.703ms       0.000us         0.00%      11.033ms      11.033ms             1
                                           aten::matmul         2.23%     256.000us        27.21%       3.122ms      62.440us       0.000us         0.00%       8.156ms     163.120us            50
                                               aten::mm        19.17%       2.200ms        23.15%       2.656ms      53.120us       7.752ms        76.53%       8.156ms     163.120us            50
                                           aten::linear         1.83%     210.000us        30.51%       3.501ms      70.020us       0.000us         0.00%       7.846ms     156.920us            50
         ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn         0.00%       0.000us         0.00%       0.000us       0.000us       5.554ms        54.83%       5.554ms     222.160us            25
                     aten::scaled_dot_product_attention         1.97%     226.000us        18.83%       2.161ms      86.440us       0.000us         0.00%       2.877ms     115.080us            25
              aten::_scaled_dot_product_flash_attention         3.51%     403.000us        16.86%       1.935ms      77.400us       0.000us         0.00%       2.877ms     115.080us            25
                         aten::_flash_attention_forward         4.62%     530.000us        12.10%       1.388ms      55.520us       2.377ms        23.47%       2.877ms     115.080us            25
  void pytorch_flash::flash_fwd_kernel<pytorch_flash::...         0.00%       0.000us         0.00%       0.000us       0.000us       2.377ms        23.47%       2.377ms      95.080us            25
  ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_3...         0.00%       0.000us         0.00%       0.000us       0.000us       2.198ms        21.70%       2.198ms      87.920us            25
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 11.474ms
Self CUDA time total: 10.129ms

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                              Compiled Causal Attention         9.79%       1.158ms        93.81%      11.091ms      11.091ms       0.000us         0.00%      10.544ms      10.544ms             1
                                  Torch-Compiled Region         8.51%       1.006ms        82.19%       9.717ms     388.680us       0.000us         0.00%      10.544ms     421.760us            25
                                       CompiledFunction        41.11%       4.861ms        72.93%       8.622ms     344.880us       0.000us         0.00%      10.544ms     421.760us            25
                                               aten::mm         7.96%     941.000us        12.70%       1.502ms      30.040us       7.755ms        76.49%       7.843ms     156.860us            50
         ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_tn         0.00%       0.000us         0.00%       0.000us       0.000us       5.556ms        54.80%       5.556ms     222.240us            25
              aten::_scaled_dot_product_flash_attention         2.30%     272.000us        15.12%       1.788ms      71.520us       0.000us         0.00%       2.701ms     108.040us            25
                         aten::_flash_attention_forward         4.58%     541.000us        11.52%       1.362ms      54.480us       2.383ms        23.51%       2.701ms     108.040us            25
  void pytorch_flash::flash_fwd_kernel<pytorch_flash::...         0.00%       0.000us         0.00%       0.000us       0.000us       2.383ms        23.51%       2.383ms      95.320us            25
  ampere_fp16_s1688gemm_fp16_128x128_ldg8_f2f_stages_3...         0.00%       0.000us         0.00%       0.000us       0.000us       2.199ms        21.69%       2.199ms      87.960us            25
                                  cudaStreamIsCapturing         0.24%      28.000us         0.24%      28.000us       1.120us     222.000us         2.19%     222.000us       8.880us            25
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 11.823ms
Self CUDA time total: 10.138ms
前面的代码片段分别针对非编译模块和编译模块,生成了一份消耗 GPU 执行时间最多的前 10 个 PyTorch 函数的报告。分析显示,两个模块在 GPU 上花费的大部分时间都集中在同一组函数上。原因在于 torch.compile 非常擅长消除与 PyTorch 相关的框架开销。如果您的模型启动的是大型、高效的 CUDA 内核(比如这里的 CausalSelfAttention),那么 PyTorch 的框架开销就可以被掩盖掉。
实际上,您的模块通常不会只由单个 CausalSelfAttention 块组成。在与 Andrej Karpathy NanoGPT 仓库进行实验时,编译该模块使每个训练步骤的耗时从 6090.49ms 降至 3273.17ms!该实验是在 NanoGPT 的提交 ae3a8d5 上、使用莎士比亚数据集训练时完成的。
结论
在本教程中,我们演示了 torch.nn.functional.scaled_dot_product_attention 的基本用法。我们展示了如何使用 sdp_kernel 上下文管理器来确保在 GPU 上使用特定的实现。此外,我们构建了一个简单的 CausalSelfAttention 模块,它可以与 NestedTensor 一起工作,并且可以用 torch.compile 进行编译。在此过程中,我们还展示了如何使用性能分析工具来探索用户自定义模块的性能特征。
脚本的总运行时间:(0 分钟 7.800 秒)
下载 Python 源代码:scaled_dot_product_attention_tutorial.py
下载 Jupyter 笔记本:scaled_dot_product_attention_tutorial.ipynb
知识蒸馏教程
原文:
pytorch.org/tutorials/beginner/knowledge_distillation_tutorial.html
译者:飞龙
注意
点击这里下载完整示例代码
知识蒸馏是一种技术,它可以实现从大型、计算昂贵的模型向较小的模型进行知识转移,而不会失去有效性。这使得在性能较弱的硬件上部署成为可能,从而使评估更快速、更高效。
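下面给出经典知识蒸馏损失(即基于"软目标"的蒸馏)的一个简化示意实现(假设性示例:函数名 distillation_loss 为演示自拟,temperature、alpha 等超参数仅为演示取值,与本教程后续使用的具体实现不一定完全相同):

import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):
    # 硬标签部分:学生网络对真实标签的交叉熵
    ce_loss = F.cross_entropy(student_logits, labels)
    # 软目标部分:学生与教师在温度 T 下的输出分布之间的 KL 散度
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    kd_loss = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * (temperature ** 2)
    # 按权重组合两部分损失
    return alpha * ce_loss + (1 - alpha) * kd_loss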
在本教程中,我们将进行一系列旨在提高轻量级神经网络准确率的实验,使用更强大的网络作为教师。轻量级网络的计算成本和速度将保持不变,我们的干预仅针对其权重,而不是其前向传递。这项技术的应用可以在无人机或手机等设备中找到。在本教程中,我们不使用任何外部包,因为我们需要的一切都可以在 torch 和 torchvision 中找到。
在本教程中,您将学习:
- 如何修改模型类以提取隐藏表示并将其用于进一步计算
- 如何修改 PyTorch 中的常规训练循环,以包含额外的损失,例如用于分类的交叉熵
- 如何通过使用更复杂的模型作为教师来提高轻量级模型的性能
先决条件
- 1 GPU,4GB 内存
- PyTorch v2.0 或更高版本
- CIFAR-10 数据集(通过脚本下载并保存在名为 /data 的目录中)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# Check if GPU is available, and if not, use the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
加载 CIFAR-10
CIFAR-10 是一个包含十个类别的流行图像数据集。我们的目标是为每个输入图像预测以下类别之一。
CIFAR-10 图像示例
输入图像是 RGB 格式的,因此它们有 3 个通道,尺寸为 32x32 像素。基本上,每个图像由 3 x 32 x 32 = 3072 个数字描述,取值范围从 0 到 255。神经网络中的常见做法是对输入进行归一化,这样做有多种原因,包括避免常用激活函数中的饱和现象,增加数值稳定性。我们的归一化过程包括沿每个通道减去平均值并除以标准差。张量“mean=[0.485, 0.456, 0.406]”和“std=[0.229, 0.224, 0.225]”已经计算出来,它们代表了 CIFAR-10 预定义子集中用作训练集的每个通道的平均值和标准差。请注意,我们也在测试集中使用这些值,而不是从头开始重新计算平均值和标准差。这是因为网络是在减去和除以上述数字产生的特征上进行训练的,我们希望保持一致性。此外,在现实生活中,我们无法计算测试集的平均值和标准差,因为根据我们的假设,在那时这些数据将不可访问。
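下面是一段简短的示意代码(假设性示例),演示上述逐通道归一化在数值上等价于手动减去均值再除以标准差:

import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# img 模拟一张取值在 [0, 1] 的 3x32x32 图像(即 ToTensor 的输出)
img = torch.rand(3, 32, 32)
normalized = (img - mean) / std  # 与 transforms.Normalize(mean, std) 的结果一致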
最后,我们经常将这个留出的集合称为验证集,并在优化模型在验证集上的性能后使用一个单独的集合,称为测试集。这样做是为了避免基于单一指标的贪婪和偏见优化选择模型。
# Below we are preprocessing data for CIFAR-10. We use an arbitrary batch size of 128.
transforms_cifar = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Loading the CIFAR-10 dataset:
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)
Files already downloaded and verified
Files already downloaded and verified
注意
这一部分仅适用于希望快速获得结果的 CPU 用户。只有在您对小规模实验感兴趣时才使用此选项。请记住,这些代码在任何 GPU 上都应该运行得相当快。只需从训练/测试数据集中仅选择前 num_images_to_keep 张图片即可。
#from torch.utils.data import Subset
#num_images_to_keep = 2000
#train_dataset = Subset(train_dataset, range(min(num_images_to_keep, 50_000)))
#test_dataset = Subset(test_dataset, range(min(num_images_to_keep, 10_000)))
#Dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)
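可以用下面的示意代码(假设性示例)快速检查一个批次的形状,确认数据管道按预期工作:

# 取出一个批次,检查形状(示意代码)
images, labels = next(iter(train_loader))
print(images.shape)  # 预期: torch.Size([128, 3, 32, 32])
print(labels.shape)  # 预期: torch.Size([128])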