Self-Attention 原理与代码实现

简介: Self-Attention 原理与代码实现

Self-Attention 是一种用于处理序列数据的机制,最初在 Transformer 模型中广泛使用。它允许模型在处理序列数据时,对序列中不同位置的元素进行加权聚合,从而更好地捕捉元素之间的依赖关系和全局信息。

 

### Self-Attention 的原理:

 

1. **Query、Key、Value**:

  - 对于输入序列中的每个元素,通过三个线性变换(分别是 Query 矩阵、Key 矩阵和 Value 矩阵)将输入向量映射到三个不同的表示空间。这些映射可以通过学习得到,通常是通过权重矩阵乘法实现的。

 

2. **计算注意力权重**:

  - 接下来,计算 Query 和每个 Key 之间的相关性得分,一般使用点积(Dot Product)计算 Query 和 Key 的相似度。然后将这些得分进行缩放(通常除以特征维度的平方根)并应用 Softmax 函数,得到注意力权重,表示了每个元素与其他元素的重要性。

 

3. **加权求和**:

  - 最后,将每个 Value 向量乘以对应的注意力权重,然后将所有加权后的 Value 向量相加,得到最终的输出表示。这个输出表示包含了所有元素的信息,且每个元素的权重由注意力机制决定。

 

4. **多头注意力**:

  - 为了增强模型的表达能力,通常会使用多头注意力(Multi-Head Attention),即同时学习多组不同的 Query、Key、Value 矩阵,最后将它们拼接并再次进行线性变换得到最终输出。

 

Self-Attention 的优点在于可以捕捉长距离依赖关系,同时允许模型在不同位置之间建立直接的联系,而无需像循环神经网络(RNN)那样依赖序列的顺序。这使得 Self-Attention 在处理长序列和并行计算方面具有优势,因此在自然语言处理等领域得到了广泛应用。

 

以下是一个简单的 Self-Attention 的 PyTorch 实现示例:

 

```python
import torch
import torch.nn.functional as F
 
class SelfAttention(torch.nn.Module):
    def __init__(self, input_dim, heads):
        super(SelfAttention, self).__init__()
        self.input_dim = input_dim
        self.heads = heads
        self.head_dim = input_dim // heads
 
        self.W_q = torch.nn.Linear(input_dim, input_dim)
        self.W_k = torch.nn.Linear(input_dim, input_dim)
        self.W_v = torch.nn.Linear(input_dim, input_dim)
 
        self.W_o = torch.nn.Linear(input_dim, input_dim)
 
    def forward(self, x):
        batch_size = x.shape[0]
 
        # Linear transformation to get Q, K, V
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
 
        # Reshape Q, K, V to have multiple heads
        Q = Q.view(batch_size, -1, self.heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.heads, self.head_dim).permute(0, 2, 1, 3)
 
        # Compute attention scores
        scores = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
 
        # Apply attention weights to V
        attention_output = torch.matmul(attention_weights, V)
 
        # Reshape and concatenate heads
        attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
        attention_output = attention_output.view(batch_size, -1, self.input_dim)
 
        # Final linear transformation
        output = self.W_o(attention_output)
 
        return output
 
# 使用示例
input_dim = 64
seq_length = 10
heads = 8
input_data = torch.randn(1, seq_length, input_dim)  # 生成随机输入数据
self_attention = SelfAttention(input_dim, heads)
output = self_attention(input_data)
print(output.shape)  # 输出形状:torch.Size([1, 10, 64])
```

 

在这个示例中,定义了一个包含多头注意力机制的 SelfAttention 类。通过传入输入数据,可以得到经过 Self-Attention 处理后的输出。请注意,实际应用中可能会有更复杂的 Self-Attention 变体和改进。

相关文章
|
人工智能 边缘计算 算法
DistilQwen2.5-R1发布:知识蒸馏助推小模型深度思考
DistilQwen2.5-R1通过知识蒸馏技术,将大规模深度推理模型的知识迁移到小模型中,显著提升了小模型的推理能力。实验结果表明,DistilQwen2.5-R1在数学、代码和科学问题等多个基准测试中表现优异,尤其在7B参数量级上超越了其他开源蒸馏模型。 本文将深入阐述 DistilQwen2.5-R1 的蒸馏算法、性能评估,并且提供在阿里云人工智能平台 PAI 上的使用指南及相关下载教程。
|
10月前
|
机器学习/深度学习 人工智能 算法
GSPO:Qwen让大模型强化学习训练告别崩溃,解决序列级强化学习中的稳定性问题
这是7月份的一篇论文,Qwen团队提出的群组序列策略优化算法及其在大规模语言模型强化学习训练中的技术突破
1931 0
GSPO:Qwen让大模型强化学习训练告别崩溃,解决序列级强化学习中的稳定性问题
|
SQL 缓存 监控
后端接口性能优化分析-问题发现&问题定义(上)
后端接口性能优化分析-问题发现&问题定义
795 0
|
存储 编解码 数据可视化
云VR:虚拟现实专业化的下一步
但究竟什么是云VR,云VR将如何帮助各行各业开展业务?本文将带您了解VR和云VR的未来,以及它与我们目前可以体验的沉浸式VR有何不同。
|
机器学习/深度学习 数据采集 存储
大模型微调知识与实践分享
本文详细介绍了大型语言模型(LLM)的结构、参数量、显存占用、存储需求以及微调过程中的关键技术点,包括Prompt工程、数据构造、LoRA微调方法等。
2987 72
大模型微调知识与实践分享
|
机器学习/深度学习 算法
广义优势估计(GAE):端策略优化PPO中偏差与方差平衡的关键技术
广义优势估计(GAE)由Schulman等人于2016年提出,是近端策略优化(PPO)算法的核心理论基础。它通过平衡偏差与方差,解决了强化学习中的信用分配问题,即如何准确判定历史动作对延迟奖励的贡献。GAE基于资格迹和TD-λ思想,采用n步优势的指数加权平均方法,将优势函数有效集成到损失函数中,为策略优化提供稳定梯度信号。相比TD-λ,GAE更适用于现代策略梯度方法,推动了高效强化学习算法的发展。
2475 3
广义优势估计(GAE):端策略优化PPO中偏差与方差平衡的关键技术
|
11月前
|
人工智能 Nacos 开发者
手把手教你搭建MCP服务器
Model Context Protocol(MCP)正成为AI智能体连接外部工具的主流标准。本文详解两种搭建方案,助你构建专属AI工具扩展引擎,实现工具调用的标准化与高效集成。
|
机器学习/深度学习 自然语言处理 PyTorch
深入剖析Transformer架构中的多头注意力机制
多头注意力机制(Multi-Head Attention)是Transformer模型中的核心组件,通过并行运行多个独立的注意力机制,捕捉输入序列中不同子空间的语义关联。每个“头”独立处理Query、Key和Value矩阵,经过缩放点积注意力运算后,所有头的输出被拼接并通过线性层融合,最终生成更全面的表示。多头注意力不仅增强了模型对复杂依赖关系的理解,还在自然语言处理任务如机器翻译和阅读理解中表现出色。通过多头自注意力机制,模型在同一序列内部进行多角度的注意力计算,进一步提升了表达能力和泛化性能。
11087 48
|
文字识别 自然语言处理 数据可视化
Qwen2.5 全链路模型体验、下载、推理、微调、部署实战!
在 Qwen2 发布后的过去三个月里,许多开发者基于 Qwen2 语言模型构建了新的模型,并提供了宝贵的反馈。在这段时间里,通义千问团队专注于创建更智能、更博学的语言模型。今天,Qwen 家族的最新成员:Qwen2.5系列正式开源
Qwen2.5 全链路模型体验、下载、推理、微调、部署实战!

热门文章

最新文章