掩码注意力(Causal Attention)是生成式模型的核心技术,它传统自注意力机制有根本的不同,掩码注意力限制模型只能关注当前位置之前的tokens,确保了自回归生成的因果性。
自注意力的掩码
自注意力机制在Transformer编码器和BERT等模型中广泛应用。这种机制的特点是每个token都能访问序列中的所有其他tokens,包括前面和后面的位置。这种双向注意力让模型能够充分利用上下文信息,将静态词嵌入转换为富含语境的动态表示。
而掩码注意力作为解码器的关键组件,人为地阻断了对未来tokens的访问。这种单向约束虽然看起来是限制,实际上正是语言生成任务的核心要求——模型必须基于已有的上下文来预测下一个词,而不能"偷看"答案。
Pytorch实现
实现掩码注意力需要五个关键步骤:
先看基础的类结构定义。这里需要为Query、Key、Value分别创建线性变换层,同时初始化一个上三角掩码矩阵:
import torch.nn as nn
import torch
class CasualAttention(nn.Module):
def __init__(self,in_Dim,out_dim,context_length,Dropout=0,bias=Fasle):
super().__init__()
self.w_q=nn.Linear(in_put,out_dim,bias=bias)
self.w_k=nn.Linear(in_put,out_dim,bias=bias)
self.w_v=nn.Linear(in_put,out_dim,bias=bias)
self.Drop=nn.Dropout(Dropout) [#dropout](#dropout)
self.register_buffer("mask", torch.triu(torch.ones(context_length,context_length),diagonal=1))
register_buffer
这个方法很关键。它确保掩码矩阵会跟随模型在CPU和GPU之间移动,但不会作为可训练参数参与梯度更新。
然后就是前向传播的第一步,计算注意力分数。这部分和标准自注意力完全一样:
import torch.nn as nn
import torch
class CasualAttention(nn.Module):
def __init__(self,in_Dim,out_dim,context_length,Dropout=0,bias=Fasle):
super().__init__()
self.w_q=nn.Linear(in_put,out_dim,bias=bias)
self.w_k=nn.Linear(in_put,out_dim,bias=bias)
self.w_v=nn.Linear(in_put,out_dim,bias=bias)
self.Drop=nn.Dropout(Dropout) [#dropout](#dropout)
self.register_buffer("mask", torch.triu(torch.ones(context_length,context_length),diagonal=1))
def forward(self,x):
batch,num_tokens,in_dim = x.shape
vec_q=self.w_q(x)
vec_K=self.w_k(x)
vec_v=self.w_v(x)
[#attention](#attention)_score
attention_score= vec_q @ vec_k.transpose(1,2) # 记住我们在处理批量数据
下面就是最关键的掩码操作。在这一步
masked_fill_
函数会将掩码为True的位置填充为负无穷大,这样在后续softmax操作中这些位置的权重就会变成0:
import torch.nn as nn
import torch
class CasualAttention(nn.Module):
def __init__(self,in_Dim,out_dim,context_length,Dropout=0,bias=Fasle):
super().__init__()
self.w_q=nn.Linear(in_put,out_dim,bias=bias)
self.w_k=nn.Linear(in_put,out_dim,bias=bias)
self.w_v=nn.Linear(in_put,out_dim,bias=bias)
self.Drop=nn.Dropout(Dropout) [#dropout](#dropout)
self.register_buffer("mask", torch.triu(torch.ones(context_length,context_length),diagonal=1))
def forward(self,x):
batch,num_tokens,in_dim = x.shape
vec_q=self.w_q(x)
vec_K=self.w_k(x)
vec_v=self.w_v(x)
[#attention](#attention)_score
attention_score= vec_q @ vec_k.transpose(1,2)
[#重要的代码行](#重要的代码行) #########
attention_score.masked_fill_(mask.bool()[:num_tokens,:num_tokens],-torch.inf)
然后是就是标准的缩放和softmax归一化。这里除法运算中的
vec_k.shape[-1]
是Key向量的维度,这个缩放因子能够稳定梯度:
import torch.nn as nn
import torch
class CasualAttention(nn.Module):
def __init__(self,in_Dim,out_dim,context_length,Dropout=0,bias=Fasle):
super().__init__()
self.w_q=nn.Linear(in_put,out_dim,bias=bias)
self.w_k=nn.Linear(in_put,out_dim,bias=bias)
self.w_v=nn.Linear(in_put,out_dim,bias=bias)
self.Drop=nn.Dropout(Dropout) [#dropout](#dropout)
self.register_buffer("mask", torch.triu(torch.ones(context_length,context_length),diagonal=1))
def forward(self,x):
batch,num_tokens,in_dim = x.shape
vec_q=self.w_q(x)
vec_K=self.w_k(x)
vec_v=self.w_v(x)
[#attention](#attention)_score
attention_score= vec_q @ vec_k.transpose(1,2)
[#重要的代码行](#重要的代码行) #########
attention_score.masked_fill_(mask.bool()[:num_tokens:num_tokens],-torch.inf)
[#通过attention](#通过attention)_weight进行缩放
attention_weight=torch.softmax(attention_score/vec_k.shape[-1],dim=-1)
最后加入dropout防止过拟合(也可以不加,现在的模型基本上不会dropout了,但是为了演示,我们可以在这里加入dropout),并与Value向量相乘得到最终的上下文表示:
import torch.nn as nn
import torch
class CasualAttention(nn.Module):
def __init__(self,in_Dim,out_dim,context_length,Dropout=0,bias=Fasle):
super().__init__()
self.w_q=nn.Linear(in_put,out_dim,bias=bias)
self.w_k=nn.Linear(in_put,out_dim,bias=bias)
self.w_v=nn.Linear(in_put,out_dim,bias=bias)
self.Drop=nn.Dropout(Dropout) [#dropout](#dropout)
self.register_buffer("mask", torch.triu(torch.ones(context_length,context_length),diagonal=1))
def forward(self,x):
batch,num_tokens,in_dim = x.shape
vec_q=self.w_q(x)
vec_K=self.w_k(x)
vec_v=self.w_v(x)
[#attention](#attention)_score
attention_score= vec_q @ vec_k.transpose(1,2)
[#重要的代码行](#重要的代码行) #########
attention_score.masked_fill_(mask.bool()[:num_tokens:num_tokens],-torch.inf)
[#通过attention](#通过attention)_weight进行缩放
attention_weight=torch.softmax(attention_score/vec_k.shape[-1],dim=-1)
drop_out=self.Drop(attention_weight)
return drop_out @ vec_v
最后我们来详细解释一下这行代码:
attention_score.masked_fill_(mask.bool()[:num_tokens,num_tokens],-torch.inf)
整个掩码操作分几个部分:首先计算原始的注意力分数矩阵,然后从预先注册的上三角掩码中切取对应大小的子矩阵。
mask.bool()
将0/1矩阵转换为布尔型,这样
masked_fill_
函数就将这些位置填充负无穷。
因为负无穷,所以当这些位置经过softmax函数时,exp(-∞)会趋向于0,从而实现了完全屏蔽未来tokens的效果。切片操作
[:num_tokens,num_tokens]
处理了不同序列长度的情况,因为上下文窗口是固定的,但实际输入序列长度可能变化。
总结
这种掩码机制让GPT等模型能够逐词生成文本,每次预测都只基于已经生成的内容,这正是自回归语言模型的精髓所在。通过一个上三角掩码矩阵,就能让模型在训练时学会"单向思考",这种设计的巧妙之处在于它完美平衡了计算效率和生成质量。
从技术实现角度来看,整个过程其实就是在标准自注意力基础上加了一步
masked_fill_
操作。但正是这简单的一步,让模型具备了真正的文本生成能力。相比之下,BERT等双向模型虽然在理解任务上表现出色,但在生成任务上就显得力不从心。
掌握了掩码注意力,你就理解了GPT、LLaMA等主流生成模型的核心工作原理。下次看到这些模型的论文或代码时,相信你会有更深刻的认识。
https://avoid.overfit.cn/post/1eaccf4c67f74b27839e3c5b2372f23c
作者:VIGNESHWARAN