Self-Attention 是一种用于处理序列数据的机制,最初在 Transformer 模型中广泛使用。它允许模型在处理序列数据时,对序列中不同位置的元素进行加权聚合,从而更好地捕捉元素之间的依赖关系和全局信息。
### Self-Attention 的原理:
1. **Query、Key、Value**:
- 对于输入序列中的每个元素,通过三个线性变换(分别是 Query 矩阵、Key 矩阵和 Value 矩阵)将输入向量映射到三个不同的表示空间。这些映射可以通过学习得到,通常是通过权重矩阵乘法实现的。
2. **计算注意力权重**:
- 接下来,计算 Query 和每个 Key 之间的相关性得分,一般使用点积(Dot Product)计算 Query 和 Key 的相似度。然后将这些得分进行缩放(通常除以特征维度的平方根)并应用 Softmax 函数,得到注意力权重,表示了每个元素与其他元素的重要性。
3. **加权求和**:
- 最后,将每个 Value 向量乘以对应的注意力权重,然后将所有加权后的 Value 向量相加,得到最终的输出表示。这个输出表示包含了所有元素的信息,且每个元素的权重由注意力机制决定。
4. **多头注意力**:
- 为了增强模型的表达能力,通常会使用多头注意力(Multi-Head Attention),即同时学习多组不同的 Query、Key、Value 矩阵,最后将它们拼接并再次进行线性变换得到最终输出。
Self-Attention 的优点在于可以捕捉长距离依赖关系,同时允许模型在不同位置之间建立直接的联系,而无需像循环神经网络(RNN)那样依赖序列的顺序。这使得 Self-Attention 在处理长序列和并行计算方面具有优势,因此在自然语言处理等领域得到了广泛应用。
以下是一个简单的 Self-Attention 的 PyTorch 实现示例:
```python import torch import torch.nn.functional as F class SelfAttention(torch.nn.Module): def __init__(self, input_dim, heads): super(SelfAttention, self).__init__() self.input_dim = input_dim self.heads = heads self.head_dim = input_dim // heads self.W_q = torch.nn.Linear(input_dim, input_dim) self.W_k = torch.nn.Linear(input_dim, input_dim) self.W_v = torch.nn.Linear(input_dim, input_dim) self.W_o = torch.nn.Linear(input_dim, input_dim) def forward(self, x): batch_size = x.shape[0] # Linear transformation to get Q, K, V Q = self.W_q(x) K = self.W_k(x) V = self.W_v(x) # Reshape Q, K, V to have multiple heads Q = Q.view(batch_size, -1, self.heads, self.head_dim).permute(0, 2, 1, 3) K = K.view(batch_size, -1, self.heads, self.head_dim).permute(0, 2, 1, 3) V = V.view(batch_size, -1, self.heads, self.head_dim).permute(0, 2, 1, 3) # Compute attention scores scores = torch.matmul(Q, K.permute(0, 1, 3, 2)) / (self.head_dim ** 0.5) attention_weights = F.softmax(scores, dim=-1) # Apply attention weights to V attention_output = torch.matmul(attention_weights, V) # Reshape and concatenate heads attention_output = attention_output.permute(0, 2, 1, 3).contiguous() attention_output = attention_output.view(batch_size, -1, self.input_dim) # Final linear transformation output = self.W_o(attention_output) return output # 使用示例 input_dim = 64 seq_length = 10 heads = 8 input_data = torch.randn(1, seq_length, input_dim) # 生成随机输入数据 self_attention = SelfAttention(input_dim, heads) output = self_attention(input_data) print(output.shape) # 输出形状:torch.Size([1, 10, 64]) ```
在这个示例中,定义了一个包含多头注意力机制的 SelfAttention 类。通过传入输入数据,可以得到经过 Self-Attention 处理后的输出。请注意,实际应用中可能会有更复杂的 Self-Attention 变体和改进。