# 注意力机制的具体模块
# 兼容单头和多头
class Attention(nn.Module):
"""
Compute 'Scaled Dot Product Attention
"""
# QKV 尺寸都是 BS * ML * ES
# (或者多头情况下是 BS * HC * ML * HS,最后两维之外的维度不重要)
# 从输入计算 QKV 的过程可以统一处理,不必放到每个头里面
def forward(self, query, key, value, mask=None, dropout=None):
# 将每个批量的 Q 和 K.T 做矩阵乘法,再除以√ES,
# 得到相关性矩阵 S,尺寸为 BS * ML * ML
scores = torch.matmul(query, key.transpose(-2, -1)) \
/ math.sqrt(query.size(-1))
# 如果存在掩码则使用它
# 将 scores 的 mask == 0 的位置上的元素改为 -1e9
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 将 S 转换到概率空间,同时对其最后一维归一化
p_attn = F.softmax(scores, dim=-1)
# 如果存在 dropout 则使用
if dropout is not None:
p_attn = dropout(p_attn)
# 最后将 S 与 V 相乘得到输出
return torch.matmul(p_attn, value), p_attn
# 多头注意力就是包含很多(HC)个头,但是每个头的尺寸(HS)变为原来的 1/HC
# 把 qkv 切成小段分给每个头做运算,将结果拼起来作为整个层的输出
class MultiHeadedAttention(nn.Module):
"""
Take in model size and number of heads.
"""
# h 是头数(HC)
# d_model 是嵌入向量大小(ES)
def __init__(self, h, d_model, dropout=0.1):
super().__init__()
# 判断 ES 是否能被 HC 整除,以便结果能拼接回去
assert d_model % h == 0
# d_k 是每个头的大小 HS = ES // HC
self.d_k = d_model // h
self.h = h
# 创建输入转换为QKV的权重矩阵,Wq, Wk, Wv,尺寸均为 ES * ES
self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
# 输出应该还乘一个权重矩阵,Wo,尺寸也是 ES * ES
self.output_linear = nn.Linear(d_model, d_model)
# 创建执行注意力机制的具体模块
self.attention = Attention()
# 创建 droput 层
self.dropout = nn.Dropout(p=dropout)
def forward(self, query, key, value, mask=None):
# 获取批量大小(BS)
batch_size = query.size(0)
'''
query, key, value = [
l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
for l, x in zip(self.linear_layers, (query, key, value))
]
'''
# 将 QKV 的每个与其相应权重矩阵 Wq, Wk, Wv 相乘
lq, lk, lv = self.linear_layers
query, key, value = lq(query), lk(key), lv(value)
# 然后将他们转型为 BS * ML * HC * HS
# 也就是将最后一个维度按头部数量分割成小的向量
query, key, value = [
x.view(batch_size, -1, self.h, self.d_k)
for x in (query, key, value)
]
# 然后交换 1 和 2 维,变成 BS * HC * ML * HS
# 这样每个头的 QKV 是内存连续的,便于矩阵相乘
query, key, value = [
x.transpose(1, 2)
for x in (query, key, value)
]
# 对每个头应用注意力机制,输出尺寸不变
x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)
# 交换 1 和 2 维恢复原状,然后把每个头的输出相连接,尺寸变为 BS * ML * ES
x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
# 执行最后的矩阵相乘
return self.output_linear(x)