class SelfAttention(torch.nn.Module):
    """
    Self-attention logic, in four parts:
    + compute QKV from the input,
    + split QKV into heads,
    + compute O from QKV (inside `CoreAttention`),
    + compute the output from O
    """

    def __init__(self, config: ChatGLMConfig, layer_number, device=None):
        super(SelfAttention, self).__init__()
        # Layer index (at least 1)
        self.layer_number = max(1, layer_number)

        # ProjSize: the size of Q/K/V when MQA is disabled,
        # equal to NHead * HeadSize; it may differ from the original HidSize
        self.projection_size = config.kv_channels * config.num_attention_heads

        # HeadSize = ProjSize // NHead
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        # NHead
        self.num_attention_heads_per_partition = config.num_attention_heads

        # Whether multi-query attention (MQA) is enabled
        self.multi_query_attention = config.multi_query_attention
        # Without MQA, QKVSize is three times ProjSize
        self.qkv_hidden_size = 3 * self.projection_size
        if self.multi_query_attention:
            # With MQA enabled:
            # NGroup
            self.num_multi_query_groups_per_partition = config.multi_query_group_num
            # QKVSize = ProjSize (Q) + 2 * HeadSize * NGroup (KV)
            self.qkv_hidden_size = (
                    self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
            )
        # Linear layer that maps the input to QKV
        self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
                                         bias=config.add_bias_linear or config.add_qkv_bias,
                                         device=device, **_config_to_kwargs(config)
                                         )

        # Core module that computes O from QKV
        self.core_attention = CoreAttention(config, self.layer_number)

        # Linear layer that computes the output from O
        self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
                               device=device, **_config_to_kwargs(config)
                               )

    def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
        # Pre-allocate a KV-cache buffer; under MQA the K/V head count
        # is NGroup, otherwise NHead
        if self.multi_query_attention:
            num_attention_heads = self.num_multi_query_groups_per_partition
        else:
            num_attention_heads = self.num_attention_heads_per_partition
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            num_attention_heads,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=device,
        )

    def forward(
            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
    ):
        # Input hidden states: [SeqLen, BatchSize, HidSize]

        # Compute QKV from the input
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            # With MQA, split QKV along the last dimension into
            # Q [SeqLen, BatchSize, ProjSize]
            # and K/V [SeqLen, BatchSize, NGroup * HeadSize]
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            # Split Q into heads: [SeqLen, BatchSize, NHead, HeadSize]
            query_layer = query_layer.view(
                query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )
            # Split K into heads: [SeqLen, BatchSize, NGroup, HeadSize]
            key_layer = key_layer.view(
                key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
            # Split V into heads: [SeqLen, BatchSize, NGroup, HeadSize]
            value_layer = value_layer.view(
                value_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
        else:
            # Reshape to [SeqLen, BatchSize, NHead, 3 * HeadSize]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                               (self.num_attention_heads_per_partition,
                                3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # Split the last dimension of QKV into three equal parts,
            # giving Q/K/V, each [SeqLen, BatchSize, NHead, HeadSize]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        # Apply RoPE
        if rotary_pos_emb is not None:
            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # If a KV cache was passed in, split it into a K cache and a V cache,
        # each of shape [CacheLen, BatchSize, NGroup, HeadSize],
        # and prepend them to K and V respectively
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=0)
            value_layer = torch.cat((cache_v, value_layer), dim=0)
        # If use_cache is set, return the KV pair
        if use_cache:
            kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None

        # In MQA mode, broadcast K and V to the shape of Q:
        # [..., NGroup, ...] => [..., NGroup, 1, ...] =>
        # [..., NGroup, NHead // NGroup, ...] =>
        # [..., NHead, ...]
        if self.multi_query_attention:
            # Reshape K to [CacheSeqLen, BatchSize, NGroup, 1, HeadSize]
            key_layer = key_layer.unsqueeze(-2)
            # Broadcast K to [CacheSeqLen, BatchSize, NGroup, NHead // NGroup, HeadSize];
            # NHead // NGroup is the number of heads per group, so this
            # replicates each group's single K head NHead // NGroup times
            key_layer = key_layer.expand(
                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
            )
            # Reshape K to [CacheSeqLen, BatchSize, NHead, HeadSize]
            key_layer = key_layer.contiguous().view(
                key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )
            # Reshape V to [CacheSeqLen, BatchSize, NGroup, 1, HeadSize]
            value_layer = value_layer.unsqueeze(-2)
            # Broadcast V to [CacheSeqLen, BatchSize, NGroup, NHead // NGroup, HeadSize]
            value_layer = value_layer.expand(
                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
            )
            # Reshape V to [CacheSeqLen, BatchSize, NHead, HeadSize]
            value_layer = value_layer.contiguous().view(
                value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )

        # Pass Q, K, V and the mask into the core module to get O,
        # of shape [SeqLen, BatchSize, ProjSize]
        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask)

        # Compute the output from O, of shape [SeqLen, BatchSize, HidSize]
        output = self.dense(context_layer)

        return output, kv_cache
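To make the broadcast step concrete, here is a minimal standalone sketch of the MQA expansion of K, with assumed toy sizes (CacheSeqLen=5, BatchSize=2, NHead=8, NGroup=2, HeadSize=4; none of these come from the config above):

import torch

# Assumed toy sizes, not taken from any real config
cache_len, batch, n_head, n_group, head_size = 5, 2, 8, 2, 4

# K as produced by the MQA split: one K head per group
key = torch.randn(cache_len, batch, n_group, head_size)

# [.., NGroup, ..] -> [.., NGroup, 1, ..] -> [.., NGroup, NHead // NGroup, ..]
key = key.unsqueeze(-2).expand(-1, -1, -1, n_head // n_group, -1)

# Merge the group dims into NHead
key = key.contiguous().view(cache_len, batch, n_head, head_size)
print(key.shape)  # torch.Size([5, 2, 8, 4])

# All heads within the same group share identical K values
assert torch.equal(key[:, :, 0], key[:, :, 1])

The same unsqueeze/expand/view sequence is applied to V. Note that expand only creates a broadcast view; no data is actually copied until contiguous() materializes the tensor.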
CoreAttention
class CoreAttention(torch.nn.Module):
    '''
    Contains the logic that computes O from the per-head QKV
    '''

    def __init__(self, config: ChatGLMConfig, layer_number):
        super(CoreAttention, self).__init__()
        # Whether QK is scaled by the layer index
        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        # Whether the attention matrix is cast to FP32
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        # Scaling mode requires FP32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        # Ensure the layer index is at least 1
        self.layer_number = max(1, layer_number)

        # ProjSize = HeadSize * NHead
        projection_size = config.kv_channels * config.num_attention_heads

        # ProjSize
        self.hidden_size_per_partition = projection_size
        # HeadSize = ProjSize // NHead
        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
        # NHead
        self.num_attention_heads_per_partition = config.num_attention_heads

        # If QK scaling is enabled, the coefficient is the layer index,
        # and d = coeff * sqrt(HeadSize); otherwise d = sqrt(HeadSize)
        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
        self.coeff = coeff

        # Dropout for the attention matrix
        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        # Q: [SeqLen, BatchSize, NHead, HeadSize]
        # K: [CacheSeqLen, BatchSize, NHead, HeadSize]
        # V: [CacheSeqLen, BatchSize, NHead, HeadSize]

        # On PyTorch 2 or newer, call the built-in kernel directly
        pytorch_major_version = int(torch.__version__.split('.')[0])
        if pytorch_major_version >= 2:
            # Move to [BatchSize, NHead, (Cache)SeqLen, HeadSize]
            query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in
                                                   [query_layer, key_layer, value_layer]]
            if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
                # No mask and no KV cache: let the kernel apply the causal mask
                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                                 is_causal=True)
            else:
                if attention_mask is not None:
                    # The kernel expects True = "may attend", so invert the mask
                    attention_mask = ~attention_mask
                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                                 attention_mask)
            # Back to [SeqLen, BatchSize, NHead, HeadSize],
            # then merge the last two dims into [SeqLen, BatchSize, ProjSize]
            context_layer = context_layer.permute(2, 0, 1, 3)
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.reshape(*new_context_layer_shape)
        else:
            # Otherwise implement the computation manually

            # Shape of the attention matrix:
            # [BatchSize, NHead, SeqLen, CacheSeqLen]
            output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))

            # Merge the middle two dims of Q: [SeqLen, BatchSize * NHead, HeadSize]
            query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
            # Merge the middle two dims of K: [CacheSeqLen, BatchSize * NHead, HeadSize]
            key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

            # Buffer tensor with the same shape as the attention matrix:
            # [BatchSize * NHead, SeqLen, CacheSeqLen]
            matmul_input_buffer = torch.empty(
                output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
                device=query_layer.device
            )

            # Swap the first two dims of Q: [BatchSize * NHead, SeqLen, HeadSize]
            # Swap the first two and last two dims of K: [BatchSize * NHead, HeadSize, CacheSeqLen]
            # Compute the raw attention matrix A = Q @ K^T / d;
            # beta=0, so the buffer's contents have no effect
            matmul_result = torch.baddbmm(
                matmul_input_buffer,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=(1.0 / self.norm_factor),
            )

            # Unflatten the first dim of A: [BatchSize, NHead, SeqLen, CacheSeqLen]
            attention_scores = matmul_result.view(*output_size)

            # If attention_softmax_in_fp32 is set, cast to FP32
            if self.attention_softmax_in_fp32:
                attention_scores = attention_scores.float()
            # If the coefficient is defined, multiply it back in
            if self.coeff is not None:
                attention_scores = attention_scores * self.coeff
            # If no mask was passed and the last two dims of the attention
            # matrix are equal (i.e. no KV cache), build a causal mask
            if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
                # Initialize the mask as an all-ones matrix
                # of shape [BatchSize, 1, SeqLen, CacheSeqLen]
                attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
                                            device=attention_scores.device, dtype=torch.bool)
                # Keep only the lower triangle, zeroing the upper triangle
                attention_mask.tril_()
                # Invert, so the upper triangle is True and the lower is False
                attention_mask = ~attention_mask
            # If a mask exists, set its True positions to -inf
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
            # Apply softmax to the attention matrix
            attention_probs = F.softmax(attention_scores, dim=-1)
            # Cast back to the input dtype
            attention_probs = attention_probs.type_as(value_layer)

            # Apply dropout to the attention matrix
            attention_probs = self.attention_dropout(attention_probs)

            # Shape of O: [BatchSize, NHead, SeqLen, HeadSize]
            output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
            # Merge the middle two dims of V: [CacheSeqLen, BatchSize * NHead, HeadSize]
            value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
            # Merge the first two dims of A: [BatchSize * NHead, SeqLen, CacheSeqLen]
            attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
            # Swap the first two dims of V: [BatchSize * NHead, CacheSeqLen, HeadSize]
            # Compute O = A @ V
            context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
            # Unflatten the first dim of O: [BatchSize, NHead, SeqLen, HeadSize]
            context_layer = context_layer.view(*output_size)
            # Permute O to [SeqLen, BatchSize, NHead, HeadSize]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
            # Merge the last two dims of O: [SeqLen, BatchSize, ProjSize]
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.view(*new_context_layer_shape)

        # Return O
        return context_layer
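As a sanity check on the manual branch, the following sketch (toy shapes only; it omits the KV cache, dropout, and the layer-scaling coefficient, and uses a plain matmul in place of baddbmm) verifies that the hand-rolled mask/softmax pipeline matches torch.nn.functional.scaled_dot_product_attention in the causal, no-cache case:

import math
import torch
import torch.nn.functional as F

# Assumed toy shapes, following the [SeqLen, BatchSize, NHead, HeadSize] layout above
seq, batch, n_head, head_size = 6, 2, 4, 8
q = torch.randn(seq, batch, n_head, head_size)
k = torch.randn(seq, batch, n_head, head_size)
v = torch.randn(seq, batch, n_head, head_size)

# Manual path: A = Q @ K^T / sqrt(HeadSize), causal mask, softmax, O = A @ V
qb = q.permute(1, 2, 0, 3)  # [BatchSize, NHead, SeqLen, HeadSize]
kb = k.permute(1, 2, 0, 3)
vb = v.permute(1, 2, 0, 3)
scores = qb @ kb.transpose(-1, -2) / math.sqrt(head_size)
mask = ~torch.ones(seq, seq, dtype=torch.bool).tril()  # True above the diagonal
scores = scores.masked_fill(mask, float("-inf"))
manual = F.softmax(scores, dim=-1) @ vb

# Built-in path used on PyTorch >= 2
builtin = F.scaled_dot_product_attention(qb, kb, vb, is_causal=True)

print(torch.allclose(manual, builtin, atol=1e-5))  # True

The class's manual path additionally routes the product through torch.baddbmm with beta=0, which computes the same alpha * (Q @ K^T) while reusing a pre-allocated output buffer.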