Self-Attention 原理与代码实现

简介: Self-Attention 原理与代码实现

Self-Attention 是一种用于处理序列数据的机制,最初在 Transformer 模型中广泛使用。它允许模型在处理序列数据时,对序列中不同位置的元素进行加权聚合,从而更好地捕捉元素之间的依赖关系和全局信息。

 

### Self-Attention 的原理:

 

1. **Query、Key、Value**:

  - 对于输入序列中的每个元素,通过三个线性变换(分别是 Query 矩阵、Key 矩阵和 Value 矩阵)将输入向量映射到三个不同的表示空间。这些映射可以通过学习得到,通常是通过权重矩阵乘法实现的。

 

2. **计算注意力权重**:

  - 接下来,计算 Query 和每个 Key 之间的相关性得分,一般使用点积(Dot Product)计算 Query 和 Key 的相似度。然后将这些得分进行缩放(通常除以特征维度的平方根)并应用 Softmax 函数,得到注意力权重,表示了每个元素与其他元素的重要性。

 

3. **加权求和**:

  - 最后,将每个 Value 向量乘以对应的注意力权重,然后将所有加权后的 Value 向量相加,得到最终的输出表示。这个输出表示包含了所有元素的信息,且每个元素的权重由注意力机制决定。

 

4. **多头注意力**:

  - 为了增强模型的表达能力,通常会使用多头注意力(Multi-Head Attention),即同时学习多组不同的 Query、Key、Value 矩阵,最后将它们拼接并再次进行线性变换得到最终输出。

 

Self-Attention 的优点在于可以捕捉长距离依赖关系,同时允许模型在不同位置之间建立直接的联系,而无需像循环神经网络(RNN)那样依赖序列的顺序。这使得 Self-Attention 在处理长序列和并行计算方面具有优势,因此在自然语言处理等领域得到了广泛应用。

 

以下是一个简单的 Self-Attention 的 PyTorch 实现示例:

 

```python
import torch
import torch.nn.functional as F
 
class SelfAttention(torch.nn.Module):
    def __init__(self, input_dim, heads):
        super(SelfAttention, self).__init__()
        self.input_dim = input_dim
        self.heads = heads
        self.head_dim = input_dim // heads
 
        self.W_q = torch.nn.Linear(input_dim, input_dim)
        self.W_k = torch.nn.Linear(input_dim, input_dim)
        self.W_v = torch.nn.Linear(input_dim, input_dim)
 
        self.W_o = torch.nn.Linear(input_dim, input_dim)
 
    def forward(self, x):
        batch_size = x.shape[0]
 
        # Linear transformation to get Q, K, V
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
 
        # Reshape Q, K, V to have multiple heads
        Q = Q.view(batch_size, -1, self.heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.heads, self.head_dim).permute(0, 2, 1, 3)
 
        # Compute attention scores
        scores = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
 
        # Apply attention weights to V
        attention_output = torch.matmul(attention_weights, V)
 
        # Reshape and concatenate heads
        attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
        attention_output = attention_output.view(batch_size, -1, self.input_dim)
 
        # Final linear transformation
        output = self.W_o(attention_output)
 
        return output
 
# 使用示例
input_dim = 64
seq_length = 10
heads = 8
input_data = torch.randn(1, seq_length, input_dim)  # 生成随机输入数据
self_attention = SelfAttention(input_dim, heads)
output = self_attention(input_data)
print(output.shape)  # 输出形状:torch.Size([1, 10, 64])
```

 

在这个示例中,定义了一个包含多头注意力机制的 SelfAttention 类。通过传入输入数据,可以得到经过 Self-Attention 处理后的输出。请注意,实际应用中可能会有更复杂的 Self-Attention 变体和改进。

相关文章
|
机器学习/深度学习 人工智能 自然语言处理
视觉 注意力机制——通道注意力、空间注意力、自注意力
本文介绍注意力机制的概念和基本原理,并站在计算机视觉CV角度,进一步介绍通道注意力、空间注意力、混合注意力、自注意力等。
12912 58
|
机器学习/深度学习 人工智能 自然语言处理
一文搞懂【知识蒸馏】【Knowledge Distillation】算法原理
一文搞懂【知识蒸馏】【Knowledge Distillation】算法原理
一文搞懂【知识蒸馏】【Knowledge Distillation】算法原理
|
8月前
|
人工智能 边缘计算 算法
DistilQwen2.5-R1发布:知识蒸馏助推小模型深度思考
DistilQwen2.5-R1通过知识蒸馏技术,将大规模深度推理模型的知识迁移到小模型中,显著提升了小模型的推理能力。实验结果表明,DistilQwen2.5-R1在数学、代码和科学问题等多个基准测试中表现优异,尤其在7B参数量级上超越了其他开源蒸馏模型。 本文将深入阐述 DistilQwen2.5-R1 的蒸馏算法、性能评估,并且提供在阿里云人工智能平台 PAI 上的使用指南及相关下载教程。
|
4月前
|
人工智能 Nacos 开发者
手把手教你搭建MCP服务器
Model Context Protocol(MCP)正成为AI智能体连接外部工具的主流标准。本文详解两种搭建方案,助你构建专属AI工具扩展引擎,实现工具调用的标准化与高效集成。
|
机器学习/深度学习 人工智能 自然语言处理
LLM 大模型学习必知必会系列(一):大模型基础知识篇
LLM 大模型学习必知必会系列(一):大模型基础知识篇
LLM 大模型学习必知必会系列(一):大模型基础知识篇
|
文字识别 自然语言处理 数据可视化
Qwen2.5 全链路模型体验、下载、推理、微调、部署实战!
在 Qwen2 发布后的过去三个月里,许多开发者基于 Qwen2 语言模型构建了新的模型,并提供了宝贵的反馈。在这段时间里,通义千问团队专注于创建更智能、更博学的语言模型。今天,Qwen 家族的最新成员:Qwen2.5系列正式开源
Qwen2.5 全链路模型体验、下载、推理、微调、部署实战!
|
缓存 PyTorch API
Transformers 4.37 中文文档(四十)(2)
Transformers 4.37 中文文档(四十)
365 1
|
11月前
|
数据采集 前端开发 物联网
【项目实战】通过LLaMaFactory+Qwen2-VL-2B微调一个多模态医疗大模型
本文介绍了一个基于多模态大模型的医疗图像诊断项目。项目旨在通过训练一个医疗领域的多模态大模型,提高医生处理医学图像的效率,辅助诊断和治疗。作者以家中老人的脑部CT为例,展示了如何利用MedTrinity-25M数据集训练模型,经过数据准备、环境搭建、模型训练及微调、最终验证等步骤,成功使模型能够识别CT图像并给出具体的诊断意见,与专业医生的诊断结果高度吻合。
21324 162
【项目实战】通过LLaMaFactory+Qwen2-VL-2B微调一个多模态医疗大模型
|
人工智能 边缘计算 自然语言处理
DistilQwen2:通义千问大模型的知识蒸馏实践
DistilQwen2 是基于 Qwen2大模型,通过知识蒸馏进行指令遵循效果增强的、参数较小的语言模型。本文将介绍DistilQwen2 的技术原理、效果评测,以及DistilQwen2 在阿里云人工智能平台 PAI 上的使用方法,和在各开源社区的下载使用教程。