# PyTorch快餐教程2019 (2) - Multi-Head Attention

## Scaled Dot-Product Attention

$Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V$

Q乘以K的转置，再除以$d_k$的平方根进行缩放，经过一个可选的Mask，经过softmax之后，再与V相乘。

def attention(query, key, value, mask=None, dropout=None):
    """Compute 'Scaled Dot Product Attention'.

    Args:
        query, key, value: tensors whose last dimension is d_k; the
            second-to-last dimension of key/value is the source length.
        mask: optional tensor broadcastable to the score matrix; positions
            where ``mask == 0`` are blocked from attention.
        dropout: optional dropout *module* applied to the attention weights.

    Returns:
        (output, p_attn): attended values and the attention weight matrix.
    """
    d_k = query.size(-1)
    # Scale by sqrt(d_k) so the dot products stay in a range where softmax
    # keeps useful gradients.
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)
    # BUG FIX: `mask` was accepted but never applied. Blocked positions get
    # a large negative score so softmax drives their weight to ~0.
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn

$MultiHead(Q,K,V)=Concat(head_1,...,head_h)W^O$

$W_i^Q \in \mathbb{R}^{d_{model} \times d_k},\ W_i^K \in \mathbb{R}^{d_{model} \times d_k},\ W_i^V \in \mathbb{R}^{d_{model} \times d_v},\ W^O \in \mathbb{R}^{hd_v \times d_{model}}$
$d_k=d_v=d_{model}/h=64$

def clones(module, N):
    """Return an nn.ModuleList of N independent deep copies of `module`."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


class MultiHeadedAttention(nn.Module):
    """Multi-head attention as in 'Attention Is All You Need'."""

    def __init__(self, h, d_model, dropout=0.1):
        """Initialize with head count h and model dimension d_model.

        Args:
            h: number of attention heads; must divide d_model.
            d_model: model (embedding) dimension.
            dropout: dropout rate applied to the attention weights.
        """
        # BUG FIX: nn.Module.__init__ must run before any submodule is
        # assigned, otherwise module registration raises AttributeError.
        super().__init__()
        # d_model must be divisible by the number of heads.
        assert d_model % h == 0
        # Following the paper's simplification, d_v == d_k.
        self.d_k = d_model // h
        self.h = h
        # Four projections: W^Q, W^K, W^V and the output projection W^O.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None  # attention weights cached from the last forward
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Run multi-head attention; mask (if given) blocks positions == 0."""
        if mask is not None:
            # Broadcast the same mask over every head.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Project, then split d_model into h heads of size d_k:
        #    (batch, seq, d_model) -> (batch, h, seq, d_k).
        query, key, value = [
            lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for lin, x in zip(self.linears, (query, key, value))
        ]

        # 2) Scaled dot-product attention on all heads in parallel.
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)

        # 3) Concatenate heads back to (batch, seq, d_model), apply W^O.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

## 再看一种写法巩固一下

class SelfAttention(nn.Module):
    """Multi-head self-attention written out projection by projection."""

    def __init__(self, hid_dim, n_heads, dropout, device):
        """Initialize with model dimension hid_dim and head count n_heads.

        Args:
            hid_dim: model dimension; must be divisible by n_heads.
            n_heads: number of attention heads.
            dropout: dropout rate applied to the attention weights.
            device: device the scale constant is moved to.
        """
        super().__init__()

        self.hid_dim = hid_dim
        # BUG FIX: forward() reads self.n_heads, but it was never stored.
        self.n_heads = n_heads

        # d_model // h must still divide evenly; same constraint, new names.
        assert hid_dim % n_heads == 0

        self.w_q = nn.Linear(hid_dim, hid_dim)
        self.w_k = nn.Linear(hid_dim, hid_dim)
        self.w_v = nn.Linear(hid_dim, hid_dim)

        self.fc = nn.Linear(hid_dim, hid_dim)

        self.do = nn.Dropout(dropout)

        # sqrt(d_k) scaling constant, kept on the target device.
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads])).to(device)

    def forward(self, query, key, value, mask=None):
        """Scaled dot-product attention over n_heads heads.

        query/key/value: (batch, seq, hid_dim); mask (optional) must be
        broadcastable to (batch, n_heads, seq_q, seq_k); 0 entries are blocked.
        Returns a (batch, seq, hid_dim) tensor.
        """
        bsz = query.shape[0]
        head_dim = self.hid_dim // self.n_heads

        # Q, K, V projections.
        Q = self.w_q(query)
        K = self.w_k(key)
        V = self.w_v(value)

        # Split into heads: (batch, seq, hid) -> (batch, heads, seq, head_dim).
        # BUG FIX: these three lines were truncated mid-expression.
        Q = Q.view(bsz, -1, self.n_heads, head_dim).permute(0, 2, 1, 3)
        K = K.view(bsz, -1, self.n_heads, head_dim).permute(0, 2, 1, 3)
        V = V.view(bsz, -1, self.n_heads, head_dim).permute(0, 2, 1, 3)

        # Step 1 of scaled dot-product attention: Q·K^T / sqrt(d_k).
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

        # BUG FIX: the mask parameter was accepted but never applied.
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)

        # Step 2: softmax over the key axis, then dropout.
        attention = self.do(torch.softmax(energy, dim=-1))

        # Step 3: weight the values with the attention matrix.
        x = torch.matmul(attention, V)

        # Merge the heads back to (batch, seq, hid_dim).
        x = x.permute(0, 2, 1, 3).contiguous()
        # BUG FIX: without this view, fc would receive a 4-D tensor
        # (batch, seq, heads, head_dim) instead of (batch, seq, hid_dim).
        x = x.view(bsz, -1, self.hid_dim)

        x = self.fc(x)

        return x

+ 订阅