Bert Pytorch 源码分析:二、注意力层

简介: Bert Pytorch 源码分析:二、注意力层
# 注意力机制的具体模块
# 兼容单头和多头
class Attention(nn.Module):
    """
    Compute 'Scaled Dot Product Attention
    """
  # QKV 尺寸都是 BS * ML * ES
  # (或者多头情况下是 BS * HC * ML * HS,最后两维之外的维度不重要)
  # 从输入计算 QKV 的过程可以统一处理,不必放到每个头里面
    def forward(self, query, key, value, mask=None, dropout=None):
    # 将每个批量的 Q 和 K.T 做矩阵乘法,再除以√ES,
    # 得到相关性矩阵 S,尺寸为 BS * ML * ML
        scores = torch.matmul(query, key.transpose(-2, -1)) \
                 / math.sqrt(query.size(-1))
    # 如果存在掩码则使用它
    # 将 scores 的 mask == 0 的位置上的元素改为 -1e9
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        # 将 S 转换到概率空间,同时对其最后一维归一化
        p_attn = F.softmax(scores, dim=-1)
    # 如果存在 dropout 则使用
        if dropout is not None:
            p_attn = dropout(p_attn)
    # 最后将 S 与 V 相乘得到输出
        return torch.matmul(p_attn, value), p_attn
# 多头注意力就是包含很多(HC)个头,但是每个头的尺寸(HS)变为原来的 1/HC
# 把 qkv 切成小段分给每个头做运算,将结果拼起来作为整个层的输出
class MultiHeadedAttention(nn.Module):
    """
    Take in model size and number of heads.
    """
  # h 是头数(HC)
  # d_model 是嵌入向量大小(ES)
    def __init__(self, h, d_model, dropout=0.1):
        super().__init__()
    # 判断 ES 是否能被 HC 整除,以便结果能拼接回去
        assert d_model % h == 0
    # d_k 是每个头的大小 HS = ES // HC
        self.d_k = d_model // h
        self.h = h
    # 创建输入转换为QKV的权重矩阵,Wq, Wk, Wv,尺寸均为 ES * ES
        self.linear_layers = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])
    # 输出应该还乘一个权重矩阵,Wo,尺寸也是 ES * ES
        self.output_linear = nn.Linear(d_model, d_model)
    # 创建执行注意力机制的具体模块
        self.attention = Attention()
    # 创建 droput 层
        self.dropout = nn.Dropout(p=dropout)
    def forward(self, query, key, value, mask=None):
    # 获取批量大小(BS)
        batch_size = query.size(0)
    '''
        query, key, value = [
      l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        for l, x in zip(self.linear_layers, (query, key, value))
    ]
    '''
    # 将 QKV 的每个与其相应权重矩阵 Wq, Wk, Wv 相乘
    lq, lk, lv = self.linear_layers
    query, key, value = lq(query), lk(key), lv(value) 
    # 然后将他们转型为 BS * ML * HC * HS
    # 也就是将最后一个维度按头部数量分割成小的向量
    query, key, value = [
      x.view(batch_size, -1, self.h, self.d_k)
      for x in (query, key, value)
    ]
    # 然后交换 1 和 2 维,变成 BS * HC * ML  * HS
    # 这样每个头的 QKV 是内存连续的,便于矩阵相乘
    query, key, value = [
      x.transpose(1, 2)
      for x in (query, key, value)
    ]
        # 对每个头应用注意力机制,输出尺寸不变
        x, attn = self.attention(query, key, value, mask=mask, dropout=self.dropout)
        # 交换 1 和 2 维恢复原状,然后把每个头的输出相连接,尺寸变为 BS * ML * ES
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
    # 执行最后的矩阵相乘
        return self.output_linear(x)

缩写表

  • BS:批量大小,即一批数据中样本大小,训练集和测试集可能不同,那就是TBS和VBS
  • ES:嵌入大小,嵌入向量空间的维数,也是注意力层的隐藏单元数量,GPT 中一般是 768
  • ML:输入序列最大长度,一般是512或者1024,不够需要用<pad>填充
  • HC:头部的数量,需要能够整除ES,因为每个头的输出拼接起来才是层的输出
  • HS:头部大小,等于ES // HC
  • VS:词汇表大小,也就是词的种类数量

尺寸备注

  • 嵌入层的矩阵尺寸应该是VS * ES
  • 注意力层的输入尺寸是BS * ML * ES
  • 输出以及 Q K V 和输入形状相同
  • 每个头的 QKV 尺寸为BS * ML * HS
  • 权重矩阵尺寸为ES * ES
  • 相关矩阵 S 尺寸为BS * ML * ML
相关文章
|
3月前
|
机器学习/深度学习 关系型数据库 MySQL
大模型中常用的注意力机制GQA详解以及Pytorch代码实现
GQA是一种结合MQA和MHA优点的注意力机制,旨在保持MQA的速度并提供MHA的精度。它将查询头分成组,每组共享键和值。通过Pytorch和einops库,可以简洁实现这一概念。GQA在保持高效性的同时接近MHA的性能,是高负载系统优化的有力工具。相关论文和非官方Pytorch实现可进一步探究。
554 4
|
3月前
|
PyTorch 算法框架/工具
Bert Pytorch 源码分析:五、模型架构简图 REV1
Bert Pytorch 源码分析:五、模型架构简图 REV1
57 0
|
3月前
|
PyTorch 算法框架/工具
Bert Pytorch 源码分析:四、编解码器
Bert Pytorch 源码分析:四、编解码器
71 0
|
3月前
|
PyTorch 算法框架/工具
Bert Pytorch 源码分析:三、Transformer块
Bert Pytorch 源码分析:三、Transformer块
60 0
|
3月前
|
PyTorch 算法框架/工具
Bert Pytorch 源码分析:五、模型架构简图
Bert Pytorch 源码分析:五、模型架构简图
51 0
|
1月前
|
机器学习/深度学习 数据采集 自然语言处理
注意力机制中三种掩码技术详解和Pytorch实现
**注意力机制中的掩码在深度学习中至关重要,如Transformer模型所用。掩码类型包括:填充掩码(忽略填充数据)、序列掩码(控制信息流)和前瞻掩码(自回归模型防止窥视未来信息)。通过创建不同掩码,如上三角矩阵,模型能正确处理变长序列并保持序列依赖性。在注意力计算中,掩码修改得分,确保模型学习的有效性。这些技术在现代NLP和序列任务中是核心组件。**
72 12
|
3月前
|
机器学习/深度学习 自然语言处理 PyTorch
Pytorch图像处理注意力机制SENet CBAM ECA模块解读
注意力机制最初是为了解决自然语言处理(NLP)任务中的问题而提出的,它使得模型能够在处理序列数据时动态地关注不同位置的信息。随后,注意力机制被引入到图像处理任务中,为深度学习模型提供了更加灵活和有效的信息提取能力。注意力机制的核心思想是根据输入数据的不同部分,动态地调整模型的注意力,从而更加关注对当前任务有用的信息。
248 0
|
6天前
|
机器学习/深度学习 人工智能 PyTorch
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
AI智能体研发之路-模型篇(五):pytorch vs tensorflow框架DNN网络结构源码级对比
20 1
|
1月前
|
机器学习/深度学习 算法 PyTorch
使用Pytorch中从头实现去噪扩散概率模型(DDPM)
在本文中,我们将构建基础的无条件扩散模型,即去噪扩散概率模型(DDPM)。从探究算法的直观工作原理开始,然后在PyTorch中从头构建它。本文主要关注算法背后的思想和具体实现细节。
8619 3
|
21天前
|
机器学习/深度学习 人工智能 PyTorch
人工智能平台PAI使用问题之如何布置一个PyTorch的模型
阿里云人工智能平台PAI是一个功能强大、易于使用的AI开发平台,旨在降低AI开发门槛,加速创新,助力企业和开发者高效构建、部署和管理人工智能应用。其中包含了一系列相互协同的产品与服务,共同构成一个完整的人工智能开发与应用生态系统。以下是对PAI产品使用合集的概述,涵盖数据处理、模型开发、训练加速、模型部署及管理等多个环节。