Transformer自回归关键技术:掩码注意力原理与PyTorch完整实现

本文涉及的产品
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时数仓Hologres,5000CU*H 100GB 3个月
实时计算 Flink 版,1000CU*H 3个月
简介: 掩码注意力是生成模型的核心,通过上三角掩码限制模型仅关注当前及之前token,确保自回归因果性。相比BERT的双向注意力,它实现单向生成,是GPT等模型逐词预测的关键机制,核心仅需一步`masked_fill_`操作。

掩码注意力(Causal Attention)是生成式模型的核心技术,它传统自注意力机制有根本的不同,掩码注意力限制模型只能关注当前位置之前的tokens,确保了自回归生成的因果性。

自注意力的掩码

自注意力机制在Transformer编码器和BERT等模型中广泛应用。这种机制的特点是每个token都能访问序列中的所有其他tokens,包括前面和后面的位置。这种双向注意力让模型能够充分利用上下文信息,将静态词嵌入转换为富含语境的动态表示。

而掩码注意力作为解码器的关键组件,人为地阻断了对未来tokens的访问。这种单向约束虽然看起来是限制,实际上正是语言生成任务的核心要求——模型必须基于已有的上下文来预测下一个词,而不能"偷看"答案。

Pytorch实现

实现掩码注意力需要五个关键步骤:

先看基础的类结构定义。这里需要为Query、Key、Value分别创建线性变换层,同时初始化一个上三角掩码矩阵:

 import torch.nn as nn  
import torch  

class CasualAttention(nn.Module):  
    def __init__(self,in_Dim,out_dim,context_length,Dropout=0,bias=Fasle):  
        super().__init__()  
        self.w_q=nn.Linear(in_put,out_dim,bias=bias)  
        self.w_k=nn.Linear(in_put,out_dim,bias=bias)  
        self.w_v=nn.Linear(in_put,out_dim,bias=bias)  
        self.Drop=nn.Dropout(Dropout) [#dropout](#dropout)   
         self.register_buffer("mask", torch.triu(torch.ones(context_length,context_length),diagonal=1))
register_buffer

这个方法很关键。它确保掩码矩阵会跟随模型在CPU和GPU之间移动,但不会作为可训练参数参与梯度更新。

然后就是前向传播的第一步,计算注意力分数。这部分和标准自注意力完全一样:

 import torch.nn as nn  
import torch  

class CasualAttention(nn.Module):  
    def __init__(self,in_Dim,out_dim,context_length,Dropout=0,bias=Fasle):  
        super().__init__()  
        self.w_q=nn.Linear(in_put,out_dim,bias=bias)  
        self.w_k=nn.Linear(in_put,out_dim,bias=bias)  
        self.w_v=nn.Linear(in_put,out_dim,bias=bias)  
        self.Drop=nn.Dropout(Dropout) [#dropout](#dropout)   
        self.register_buffer("mask", torch.triu(torch.ones(context_length,context_length),diagonal=1))  

    def forward(self,x):  
        batch,num_tokens,in_dim = x.shape   
        vec_q=self.w_q(x)  
        vec_K=self.w_k(x)  
        vec_v=self.w_v(x)  

        [#attention](#attention)_score  
         attention_score= vec_q @ vec_k.transpose(1,2) # 记住我们在处理批量数据

下面就是最关键的掩码操作。在这一步

masked_fill_

函数会将掩码为True的位置填充为负无穷大,这样在后续softmax操作中这些位置的权重就会变成0:

 import torch.nn as nn  
import torch  

class CasualAttention(nn.Module):  
    def __init__(self,in_Dim,out_dim,context_length,Dropout=0,bias=Fasle):  
        super().__init__()  
        self.w_q=nn.Linear(in_put,out_dim,bias=bias)  
        self.w_k=nn.Linear(in_put,out_dim,bias=bias)  
        self.w_v=nn.Linear(in_put,out_dim,bias=bias)  
        self.Drop=nn.Dropout(Dropout) [#dropout](#dropout)   
        self.register_buffer("mask", torch.triu(torch.ones(context_length,context_length),diagonal=1))  

    def forward(self,x):  
        batch,num_tokens,in_dim = x.shape   
        vec_q=self.w_q(x)  
        vec_K=self.w_k(x)  
        vec_v=self.w_v(x)  

        [#attention](#attention)_score  
        attention_score= vec_q @ vec_k.transpose(1,2)  
        [#重要的代码行](#重要的代码行) #########  
         attention_score.masked_fill_(mask.bool()[:num_tokens,:num_tokens],-torch.inf)

然后是就是标准的缩放和softmax归一化。这里除法运算中的

vec_k.shape[-1]

是Key向量的维度,这个缩放因子能够稳定梯度:

 import torch.nn as nn  
import torch  

class CasualAttention(nn.Module):  
    def __init__(self,in_Dim,out_dim,context_length,Dropout=0,bias=Fasle):  
        super().__init__()  
        self.w_q=nn.Linear(in_put,out_dim,bias=bias)  
        self.w_k=nn.Linear(in_put,out_dim,bias=bias)  
        self.w_v=nn.Linear(in_put,out_dim,bias=bias)  
        self.Drop=nn.Dropout(Dropout) [#dropout](#dropout)   
        self.register_buffer("mask", torch.triu(torch.ones(context_length,context_length),diagonal=1))  

    def forward(self,x):  
        batch,num_tokens,in_dim = x.shape   
        vec_q=self.w_q(x)  
        vec_K=self.w_k(x)  
        vec_v=self.w_v(x)  

        [#attention](#attention)_score  
        attention_score= vec_q @ vec_k.transpose(1,2)  
        [#重要的代码行](#重要的代码行) #########  
        attention_score.masked_fill_(mask.bool()[:num_tokens:num_tokens],-torch.inf)  
        [#通过attention](#通过attention)_weight进行缩放  
         attention_weight=torch.softmax(attention_score/vec_k.shape[-1],dim=-1)

最后加入dropout防止过拟合(也可以不加,现在的模型基本上不会dropout了,但是为了演示,我们可以在这里加入dropout),并与Value向量相乘得到最终的上下文表示:

 import torch.nn as nn  
import torch  

class CasualAttention(nn.Module):  
    def __init__(self,in_Dim,out_dim,context_length,Dropout=0,bias=Fasle):  
        super().__init__()  
        self.w_q=nn.Linear(in_put,out_dim,bias=bias)  
        self.w_k=nn.Linear(in_put,out_dim,bias=bias)  
        self.w_v=nn.Linear(in_put,out_dim,bias=bias)  
        self.Drop=nn.Dropout(Dropout) [#dropout](#dropout)   
        self.register_buffer("mask", torch.triu(torch.ones(context_length,context_length),diagonal=1))  

    def forward(self,x):  
        batch,num_tokens,in_dim = x.shape   
        vec_q=self.w_q(x)  
        vec_K=self.w_k(x)  
        vec_v=self.w_v(x)  

        [#attention](#attention)_score  
        attention_score= vec_q @ vec_k.transpose(1,2)  
        [#重要的代码行](#重要的代码行) #########  
        attention_score.masked_fill_(mask.bool()[:num_tokens:num_tokens],-torch.inf)  
        [#通过attention](#通过attention)_weight进行缩放  
        attention_weight=torch.softmax(attention_score/vec_k.shape[-1],dim=-1)  
        drop_out=self.Drop(attention_weight)  
         return drop_out @ vec_v

最后我们来详细解释一下这行代码:

 attention_score.masked_fill_(mask.bool()[:num_tokens,num_tokens],-torch.inf)

整个掩码操作分几个部分:首先计算原始的注意力分数矩阵,然后从预先注册的上三角掩码中切取对应大小的子矩阵。

mask.bool()

将0/1矩阵转换为布尔型,这样

masked_fill_

函数就将这些位置填充负无穷。

因为负无穷,所以当这些位置经过softmax函数时,exp(-∞)会趋向于0,从而实现了完全屏蔽未来tokens的效果。切片操作

[:num_tokens,num_tokens]

处理了不同序列长度的情况,因为上下文窗口是固定的,但实际输入序列长度可能变化。

总结

这种掩码机制让GPT等模型能够逐词生成文本,每次预测都只基于已经生成的内容,这正是自回归语言模型的精髓所在。通过一个上三角掩码矩阵,就能让模型在训练时学会"单向思考",这种设计的巧妙之处在于它完美平衡了计算效率和生成质量。

从技术实现角度来看,整个过程其实就是在标准自注意力基础上加了一步

masked_fill_

操作。但正是这简单的一步,让模型具备了真正的文本生成能力。相比之下,BERT等双向模型虽然在理解任务上表现出色,但在生成任务上就显得力不从心。

掌握了掩码注意力,你就理解了GPT、LLaMA等主流生成模型的核心工作原理。下次看到这些模型的论文或代码时,相信你会有更深刻的认识。

https://avoid.overfit.cn/post/1eaccf4c67f74b27839e3c5b2372f23c

作者:VIGNESHWARAN

目录
相关文章
|
8天前
|
弹性计算 关系型数据库 微服务
基于 Docker 与 Kubernetes(K3s)的微服务:阿里云生产环境扩容实践
在微服务架构中,如何实现“稳定扩容”与“成本可控”是企业面临的核心挑战。本文结合 Python FastAPI 微服务实战,详解如何基于阿里云基础设施,利用 Docker 封装服务、K3s 实现容器编排,构建生产级微服务架构。内容涵盖容器构建、集群部署、自动扩缩容、可观测性等关键环节,适配阿里云资源特性与服务生态,助力企业打造低成本、高可靠、易扩展的微服务解决方案。
1194 4
|
7天前
|
机器学习/深度学习 人工智能 前端开发
通义DeepResearch全面开源!同步分享可落地的高阶Agent构建方法论
通义研究团队开源发布通义 DeepResearch —— 首个在性能上可与 OpenAI DeepResearch 相媲美、并在多项权威基准测试中取得领先表现的全开源 Web Agent。
950 12
|
6天前
|
机器学习/深度学习 物联网
Wan2.2再次开源数字人:Animate-14B!一键实现电影角色替换和动作驱动
今天,通义万相的视频生成模型又又又开源了!Wan2.2系列模型家族新增数字人成员Wan2.2-Animate-14B。
536 11
|
17天前
|
人工智能 运维 安全
|
8天前
|
弹性计算 Kubernetes jenkins
如何在 ECS/EKS 集群中有效使用 Jenkins
本文探讨了如何将 Jenkins 与 AWS ECS 和 EKS 集群集成,以构建高效、灵活且具备自动扩缩容能力的 CI/CD 流水线,提升软件交付效率并优化资源成本。
341 0
|
8天前
|
消息中间件 Java Apache
SpringBoot集成RocketMq
RocketMQ 是一款开源的分布式消息中间件,采用纯 Java 编写,支持事务消息、顺序消息、批量消息、定时消息及消息回溯等功能。其优势包括去除对 ZooKeeper 的依赖、支持异步和同步刷盘、高吞吐量及消息过滤等特性。RocketMQ 具备高可用性和高可靠性,适用于大规模分布式系统,能有效保障消息传输的一致性和顺序性。
463 2
|
15天前
|
人工智能 异构计算
敬请锁定《C位面对面》,洞察通用计算如何在AI时代持续赋能企业创新,助力业务发展!
敬请锁定《C位面对面》,洞察通用计算如何在AI时代持续赋能企业创新,助力业务发展!
|
8天前
|
云栖大会
阿里云云栖大会2025年9月24日开启,免费申请大会门票,速度领取~
2025云栖大会将于9月24-26日举行,官网免费预约畅享票,审核后短信通知,持证件入场
1566 12