Self-Attention 原理与代码实现

简介: Self-Attention 原理与代码实现

Self-Attention 是一种用于处理序列数据的机制,最初在 Transformer 模型中广泛使用。它允许模型在处理序列数据时,对序列中不同位置的元素进行加权聚合,从而更好地捕捉元素之间的依赖关系和全局信息。

 

### Self-Attention 的原理:

 

1. **Query、Key、Value**:

  - 对于输入序列中的每个元素,通过三个线性变换(分别是 Query 矩阵、Key 矩阵和 Value 矩阵)将输入向量映射到三个不同的表示空间。这些映射可以通过学习得到,通常是通过权重矩阵乘法实现的。

 

2. **计算注意力权重**:

  - 接下来,计算 Query 和每个 Key 之间的相关性得分,一般使用点积(Dot Product)计算 Query 和 Key 的相似度。然后将这些得分进行缩放(通常除以特征维度的平方根)并应用 Softmax 函数,得到注意力权重,表示了每个元素与其他元素的重要性。

 

3. **加权求和**:

  - 最后,将每个 Value 向量乘以对应的注意力权重,然后将所有加权后的 Value 向量相加,得到最终的输出表示。这个输出表示包含了所有元素的信息,且每个元素的权重由注意力机制决定。

 

4. **多头注意力**:

  - 为了增强模型的表达能力,通常会使用多头注意力(Multi-Head Attention),即同时学习多组不同的 Query、Key、Value 矩阵,最后将它们拼接并再次进行线性变换得到最终输出。

 

Self-Attention 的优点在于可以捕捉长距离依赖关系,同时允许模型在不同位置之间建立直接的联系,而无需像循环神经网络(RNN)那样依赖序列的顺序。这使得 Self-Attention 在处理长序列和并行计算方面具有优势,因此在自然语言处理等领域得到了广泛应用。

 

以下是一个简单的 Self-Attention 的 PyTorch 实现示例:

 

```python
import torch
import torch.nn.functional as F
 
class SelfAttention(torch.nn.Module):
    def __init__(self, input_dim, heads):
        super(SelfAttention, self).__init__()
        self.input_dim = input_dim
        self.heads = heads
        self.head_dim = input_dim // heads
 
        self.W_q = torch.nn.Linear(input_dim, input_dim)
        self.W_k = torch.nn.Linear(input_dim, input_dim)
        self.W_v = torch.nn.Linear(input_dim, input_dim)
 
        self.W_o = torch.nn.Linear(input_dim, input_dim)
 
    def forward(self, x):
        batch_size = x.shape[0]
 
        # Linear transformation to get Q, K, V
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
 
        # Reshape Q, K, V to have multiple heads
        Q = Q.view(batch_size, -1, self.heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.heads, self.head_dim).permute(0, 2, 1, 3)
 
        # Compute attention scores
        scores = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
 
        # Apply attention weights to V
        attention_output = torch.matmul(attention_weights, V)
 
        # Reshape and concatenate heads
        attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
        attention_output = attention_output.view(batch_size, -1, self.input_dim)
 
        # Final linear transformation
        output = self.W_o(attention_output)
 
        return output
 
# 使用示例
input_dim = 64
seq_length = 10
heads = 8
input_data = torch.randn(1, seq_length, input_dim)  # 生成随机输入数据
self_attention = SelfAttention(input_dim, heads)
output = self_attention(input_data)
print(output.shape)  # 输出形状:torch.Size([1, 10, 64])
```

 

在这个示例中,定义了一个包含多头注意力机制的 SelfAttention 类。通过传入输入数据,可以得到经过 Self-Attention 处理后的输出。请注意,实际应用中可能会有更复杂的 Self-Attention 变体和改进。

相关文章
|
11月前
|
人工智能 边缘计算 算法
DistilQwen2.5-R1发布:知识蒸馏助推小模型深度思考
DistilQwen2.5-R1通过知识蒸馏技术,将大规模深度推理模型的知识迁移到小模型中,显著提升了小模型的推理能力。实验结果表明,DistilQwen2.5-R1在数学、代码和科学问题等多个基准测试中表现优异,尤其在7B参数量级上超越了其他开源蒸馏模型。 本文将深入阐述 DistilQwen2.5-R1 的蒸馏算法、性能评估,并且提供在阿里云人工智能平台 PAI 上的使用指南及相关下载教程。
|
机器学习/深度学习 人工智能 自然语言处理
LLM 大模型学习必知必会系列(一):大模型基础知识篇
LLM 大模型学习必知必会系列(一):大模型基础知识篇
LLM 大模型学习必知必会系列(一):大模型基础知识篇
|
文字识别 自然语言处理 数据可视化
Qwen2.5 全链路模型体验、下载、推理、微调、部署实战!
在 Qwen2 发布后的过去三个月里,许多开发者基于 Qwen2 语言模型构建了新的模型,并提供了宝贵的反馈。在这段时间里,通义千问团队专注于创建更智能、更博学的语言模型。今天,Qwen 家族的最新成员:Qwen2.5系列正式开源
Qwen2.5 全链路模型体验、下载、推理、微调、部署实战!
|
7月前
|
人工智能 Nacos 开发者
手把手教你搭建MCP服务器
Model Context Protocol(MCP)正成为AI智能体连接外部工具的主流标准。本文详解两种搭建方案,助你构建专属AI工具扩展引擎,实现工具调用的标准化与高效集成。
|
11月前
|
机器学习/深度学习 算法
广义优势估计(GAE):端策略优化PPO中偏差与方差平衡的关键技术
广义优势估计(GAE)由Schulman等人于2016年提出,是近端策略优化(PPO)算法的核心理论基础。它通过平衡偏差与方差,解决了强化学习中的信用分配问题,即如何准确判定历史动作对延迟奖励的贡献。GAE基于资格迹和TD-λ思想,采用n步优势的指数加权平均方法,将优势函数有效集成到损失函数中,为策略优化提供稳定梯度信号。相比TD-λ,GAE更适用于现代策略梯度方法,推动了高效强化学习算法的发展。
1900 3
广义优势估计(GAE):端策略优化PPO中偏差与方差平衡的关键技术
|
机器学习/深度学习 自然语言处理 搜索推荐
承上启下:基于全域漏斗分析的主搜深度统一粗排
两阶段排序(粗排-精排)一开始是因系统性能问题提出的排序框架,因此长期以来粗排的定位一直是精排的退化版本,业内的粗排的优化方向也是持续逼近精排。我们提出以全域成交的hitrate为目标的全新指标,重新审视了召回、粗排和精排的关系,指出了全新的优化方向
94376 3
|
搜索推荐 物联网 PyTorch
Qwen2.5-7B-Instruct Lora 微调
本教程介绍如何基于Transformers和PEFT框架对Qwen2.5-7B-Instruct模型进行LoRA微调。
13424 34
Qwen2.5-7B-Instruct Lora 微调

热门文章

最新文章