Self-Attention 原理与代码实现

简介: Self-Attention 原理与代码实现

Self-Attention 是一种用于处理序列数据的机制,最初在 Transformer 模型中广泛使用。它允许模型在处理序列数据时,对序列中不同位置的元素进行加权聚合,从而更好地捕捉元素之间的依赖关系和全局信息。

 

### Self-Attention 的原理:

 

1. **Query、Key、Value**:

  - 对于输入序列中的每个元素,通过三个线性变换(分别是 Query 矩阵、Key 矩阵和 Value 矩阵)将输入向量映射到三个不同的表示空间。这些映射可以通过学习得到,通常是通过权重矩阵乘法实现的。

 

2. **计算注意力权重**:

  - 接下来,计算 Query 和每个 Key 之间的相关性得分,一般使用点积(Dot Product)计算 Query 和 Key 的相似度。然后将这些得分进行缩放(通常除以特征维度的平方根)并应用 Softmax 函数,得到注意力权重,表示了每个元素与其他元素的重要性。

 

3. **加权求和**:

  - 最后,将每个 Value 向量乘以对应的注意力权重,然后将所有加权后的 Value 向量相加,得到最终的输出表示。这个输出表示包含了所有元素的信息,且每个元素的权重由注意力机制决定。

 

4. **多头注意力**:

  - 为了增强模型的表达能力,通常会使用多头注意力(Multi-Head Attention),即同时学习多组不同的 Query、Key、Value 矩阵,最后将它们拼接并再次进行线性变换得到最终输出。

 

Self-Attention 的优点在于可以捕捉长距离依赖关系,同时允许模型在不同位置之间建立直接的联系,而无需像循环神经网络(RNN)那样依赖序列的顺序。这使得 Self-Attention 在处理长序列和并行计算方面具有优势,因此在自然语言处理等领域得到了广泛应用。

 

以下是一个简单的 Self-Attention 的 PyTorch 实现示例:

 

```python
import torch
import torch.nn.functional as F
 
class SelfAttention(torch.nn.Module):
    def __init__(self, input_dim, heads):
        super(SelfAttention, self).__init__()
        self.input_dim = input_dim
        self.heads = heads
        self.head_dim = input_dim // heads
 
        self.W_q = torch.nn.Linear(input_dim, input_dim)
        self.W_k = torch.nn.Linear(input_dim, input_dim)
        self.W_v = torch.nn.Linear(input_dim, input_dim)
 
        self.W_o = torch.nn.Linear(input_dim, input_dim)
 
    def forward(self, x):
        batch_size = x.shape[0]
 
        # Linear transformation to get Q, K, V
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
 
        # Reshape Q, K, V to have multiple heads
        Q = Q.view(batch_size, -1, self.heads, self.head_dim).permute(0, 2, 1, 3)
        K = K.view(batch_size, -1, self.heads, self.head_dim).permute(0, 2, 1, 3)
        V = V.view(batch_size, -1, self.heads, self.head_dim).permute(0, 2, 1, 3)
 
        # Compute attention scores
        scores = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
 
        # Apply attention weights to V
        attention_output = torch.matmul(attention_weights, V)
 
        # Reshape and concatenate heads
        attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
        attention_output = attention_output.view(batch_size, -1, self.input_dim)
 
        # Final linear transformation
        output = self.W_o(attention_output)
 
        return output
 
# 使用示例
input_dim = 64
seq_length = 10
heads = 8
input_data = torch.randn(1, seq_length, input_dim)  # 生成随机输入数据
self_attention = SelfAttention(input_dim, heads)
output = self_attention(input_data)
print(output.shape)  # 输出形状:torch.Size([1, 10, 64])
```

 

在这个示例中,定义了一个包含多头注意力机制的 SelfAttention 类。通过传入输入数据,可以得到经过 Self-Attention 处理后的输出。请注意,实际应用中可能会有更复杂的 Self-Attention 变体和改进。

相关文章
|
6月前
|
机器学习/深度学习 关系型数据库 MySQL
大模型中常用的注意力机制GQA详解以及Pytorch代码实现
GQA是一种结合MQA和MHA优点的注意力机制,旨在保持MQA的速度并提供MHA的精度。它将查询头分成组,每组共享键和值。通过Pytorch和einops库,可以简洁实现这一概念。GQA在保持高效性的同时接近MHA的性能,是高负载系统优化的有力工具。相关论文和非官方Pytorch实现可进一步探究。
836 4
|
6月前
|
机器学习/深度学习 人工智能 数据可视化
图解Transformer——注意力计算原理
图解Transformer——注意力计算原理
156 0
|
4月前
|
机器学习/深度学习 存储 测试技术
【YOLOv10改进-注意力机制】iRMB: 倒置残差移动块 (论文笔记+引入代码)
YOLOv10专栏介绍了融合CNN与Transformer的iRMB模块,用于轻量级模型设计。iRMB在保持高效的同时结合了局部和全局信息处理,减少了资源消耗,提升了移动端性能。在ImageNet等基准上超越SOTA,且在目标检测等任务中表现优秀。代码示例展示了iRMB的实现细节,包括自注意力机制和卷积操作的整合。更多配置信息见相关链接。
|
5月前
|
机器学习/深度学习 计算机视觉 知识图谱
【YOLOv8改进】STA(Super Token Attention) 超级令牌注意力机制 (论文笔记+引入代码)
该专栏探讨YOLO目标检测的创新改进和实战应用,介绍了使用视觉Transformer的新方法。为解决Transformer在浅层处理局部特征时的冗余问题,提出了超级令牌(Super Tokens)和超级令牌注意力(STA)机制,旨在高效建模全局上下文。通过稀疏关联学习和自注意力处理,STA降低了计算复杂度,提升了全局依赖的捕获效率。由此构建的层次化视觉Transformer在ImageNet-1K、COCO检测和ADE20K语义分割任务上展现出优秀性能。此外,文章提供了YOLOv8中实现STA的代码示例。更多详细信息和配置可在相关链接中找到。
|
6月前
|
机器学习/深度学习 存储 算法
注意力机制(一)(基本思想)
在注意力机制论文 Attention Is All You Need 中最苦恼大家的肯定是K、Q、V三个变量的含义 翻阅了CSDN、知乎大量文章后,我发现没有文章能够带大家对注意力机制建立直观的认识 大部分文章要么没有从初学者的角度出发介绍的是注意力机制上层应用,要么其作者自己也并没有真正理解注意力机制所以讲的不清不楚 所以在看完《动手学深度学习(pytorch版)》、Attention Is All You Need 论文、以及大量文章后,我开始动手写这篇专门为初学者的介绍注意力机制的文章 权
|
6月前
|
机器学习/深度学习 自然语言处理 PyTorch
Vision Transformers的注意力层概念解释和代码实现
2017年推出《Attention is All You Need》以来,transformers 已经成为自然语言处理(NLP)的最新技术。2021年,《An Image is Worth 16x16 Words》,成功地将transformers 用于计算机视觉任务。从那时起,许多基于transformers的计算机视觉体系结构被提出。
58 0
|
机器学习/深度学习 PyTorch 算法框架/工具
SE 注意力模块 原理分析与代码实现
本文介绍SE注意力模块,它是在SENet中提出的,SENet是ImageNet 2017的冠军模型;SE模块常常被用于CV模型中,能较有效提取模型精度,所以给大家介绍一下它的原理,设计思路,代码实现,如何应用在模型中。
1412 0
|
机器学习/深度学习 自然语言处理 数据可视化
图解transformer中的自注意力机制
本文将将介绍注意力的概念从何而来,它是如何工作的以及它的简单的实现。
306 0
|
机器学习/深度学习 PyTorch 算法框架/工具
Dropout的深入理解(基础介绍、模型描述、原理深入、代码实现以及变种)
Dropout的深入理解(基础介绍、模型描述、原理深入、代码实现以及变种)
反向传播+代码实现
反向传播+代码实现
111 0