Bert Pytorch 源码分析:三、Transformer块

简介: Bert Pytorch 源码分析:三、Transformer块
# PFF 层,基本相当于两个全连接
# 每个 TF 块中位于注意力层之后
class PositionwiseFeedForward(nn.Module):
    "Implements FFN equation."
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
    # LL1,权重矩阵尺寸 ES * FF 
        self.w_1 = nn.Linear(d_model, d_ff)
    # LL2,权重矩阵尺寸 FF * ES
        self.w_2 = nn.Linear(d_ff, d_model)
    # Dropout
        self.dropout = nn.Dropout(dropout)
    # 激活函数是 GELU
        self.activation = GELU()
    def forward(self, x):
    # 输入 -> LL1 -> GELU -> Dropout -> LL2 -> 输出
        return self.w_2(self.dropout(self.activation(self.w_1(x))))
# 处理 TF 块内的残差
class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
    # 层级标准化
        self.norm = LayerNorm(size)
    # Dropout
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, sublayer):
        # 输入 -> LN -> 自定义层 -> Dropout -> 残差连接 -> 输出
    #  |                                    ⬆
    #  +------------------------------------+
        return x + self.dropout(sublayer(self.norm(x)))
# GELU 是 RELU 的高斯平滑近似形式
class GELU(nn.Module):
    """
    Paper Section 3.4, last paragraph notice that BERT used the GELU instead of RELU
    """
    def forward(self, x):
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
# 层级标准化(原理见参考文献)
class LayerNorm(nn.Module):
    "Construct a layernorm module (See citation for details)."
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
    # 比例参数
        self.a_2 = nn.Parameter(torch.ones(features))
    # 偏移参数
        self.b_2 = nn.Parameter(torch.zeros(features))
    # 微小值防止除零错误
        self.eps = eps
    def forward(self, x):
    # 均值和方差都是对最后一维,也就是嵌入向量计算的
    # `keepdim=True`保持维数不变
    # 输入尺寸为 BS * ML * ES,计算之后是 BS * ML * 1
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
    # 将最后一维标准化,然后乘以比例加上偏移
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
# Transformer 块是任何 Transformer 架构的基本结构,不仅限于 BERT,
# 不同模型只是层数、头数、嵌入维度、词表、训练数据以及解码器(具体任务)不同
class TransformerBlock(nn.Module):
    """
    Bidirectional Encoder = Transformer (self-attention)
    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
    """
    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        """
        :param hidden: hidden size of transformer
        :param attn_heads: head sizes of multi-head attention
        :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size
        :param dropout: dropout rate
        """
        super().__init__()
    # 第一部分:注意力层
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden)
    # 第二部分:PFF 层
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
    # 注意力层残差模块
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
    # PFF 层的残差模块
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
    # 最后的 Dropout
        self.dropout = nn.Dropout(p=dropout)
    def forward(self, x, mask):
    # 输入 -> LN1 -> 注意力层 -> DropOut1 -> 残差连接 -> ...
    #  |                                      ↑
    #  +--------------------------------------+
    # 这里的注意力层的三个输入全是`x`,但是仍然命名为 QKV,容易引起混淆
        x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
    # ... -> LN2 -> FFN -> DropOut2 -> 残差连接 -> ...
    #  |                                  ↑
    #  +----------------------------------+
        x = self.output_sublayer(x, self.feed_forward)
    # ... -> DropOut3 -> 结果
        return self.dropout(x)
相关文章
|
4天前
|
PyTorch 算法框架/工具
Bert Pytorch 源码分析:五、模型架构简图 REV1
Bert Pytorch 源码分析:五、模型架构简图 REV1
37 0
|
4天前
|
PyTorch 算法框架/工具
Bert Pytorch 源码分析:四、编解码器
Bert Pytorch 源码分析:四、编解码器
43 0
|
4天前
|
PyTorch 算法框架/工具
Bert Pytorch 源码分析:五、模型架构简图
Bert Pytorch 源码分析:五、模型架构简图
32 0
|
4天前
|
机器学习/深度学习 自然语言处理 PyTorch
使用Transformer 模型进行时间序列预测的Pytorch代码示例
时间序列预测是一个经久不衰的主题,受自然语言处理领域的成功启发,transformer模型也在时间序列预测有了很大的发展。本文可以作为学习使用Transformer 模型的时间序列预测的一个起点。
317 2
|
4天前
|
机器学习/深度学习 编解码 PyTorch
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
Pytorch实现手写数字识别 | MNIST数据集(CNN卷积神经网络)
|
2天前
|
机器学习/深度学习 JSON PyTorch
图神经网络入门示例:使用PyTorch Geometric 进行节点分类
本文介绍了如何使用PyTorch处理同构图数据进行节点分类。首先,数据集来自Facebook Large Page-Page Network,包含22,470个页面,分为四类,具有不同大小的特征向量。为训练神经网络,需创建PyTorch Data对象,涉及读取CSV和JSON文件,处理不一致的特征向量大小并进行归一化。接着,加载边数据以构建图。通过`Data`对象创建同构图,之后数据被分为70%训练集和30%测试集。训练了两种模型:MLP和GCN。GCN在测试集上实现了80%的准确率,优于MLP的46%,展示了利用图信息的优势。
9 1
|
3天前
|
机器学习/深度学习 PyTorch 算法框架/工具
神经网络基本概念以及Pytorch实现,多线程编程面试题
神经网络基本概念以及Pytorch实现,多线程编程面试题
|
4天前
|
机器学习/深度学习 数据采集 PyTorch
构建你的第一个PyTorch神经网络模型
【4月更文挑战第17天】本文介绍了如何使用PyTorch构建和训练第一个神经网络模型。首先,准备数据集,如MNIST。接着,自定义神经网络模型`SimpleNet`,包含两个全连接层和ReLU激活函数。然后,定义交叉熵损失函数和SGD优化器。训练模型涉及多次迭代,计算损失、反向传播和参数更新。最后,测试模型性能,计算测试集上的准确率。这是一个基础的深度学习入门示例,为进一步探索复杂项目打下基础。
|
4天前
|
机器学习/深度学习 PyTorch 算法框架/工具
Python中用PyTorch机器学习神经网络分类预测银行客户流失模型
Python中用PyTorch机器学习神经网络分类预测银行客户流失模型
|
4天前
|
机器学习/深度学习 数据可视化 PyTorch
时空图神经网络ST-GNN的概念以及Pytorch实现
本文介绍了图神经网络(GNN)在处理各种领域中相互关联的图数据时的作用,如分子结构和社交网络。GNN与序列模型(如RNN)结合形成的时空图神经网络(ST-GNN)能捕捉时间和空间依赖性。文章通过图示和代码示例解释了GNN和ST-GNN的基本原理,展示了如何将GNN应用于股票市场的数据,尽管不推荐将其用于实际的股市预测。提供的PyTorch实现展示了如何将时间序列数据转换为图结构并训练ST-GNN模型。
31 1

相关实验场景

更多