Bert Pytorch 源码分析:四、编解码器

简介: Bert Pytorch 源码分析:四、编解码器
# Bert 编码器模块
# 由一个嵌入层和 NL 个 TF 层组成
class BERT(nn.Module):
    """
    BERT model : Bidirectional Encoder Representations from Transformers.
    """
    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        """
        :param vocab_size: vocab_size of total words
        :param hidden: BERT model hidden size
        :param n_layers: numbers of Transformer blocks(layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        """
        super().__init__()
    # 嵌入大小 ES
        self.hidden = hidden
    # TF 层数 NL
        self.n_layers = n_layers
    # 头部数量 HC
        self.attn_heads = attn_heads
        # FFN 层中的隐藏单元数量,记为 FF,一般是 ES 的四倍
        self.feed_forward_hidden = hidden * 4
        # 嵌入层,嵌入矩阵尺寸 VS * ES
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)
        # NL 个 TF 层
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])
    def forward(self, x, segment_info):
        # 为`<pad>`(ID = 0)设置掩码
    # 尺寸为 BS * 1 * ML * ML,以便与相似性矩阵 S 匹配
    # 在每个 BS 的 ML * ML 矩阵中,`<pad>`标记对应的行为 1,其余为零
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        # 单词 ID 传入嵌入层得到词向量
        x = self.embedding(x, segment_info)
        # 依次传入每个 TF 层,得到编码器输出
        for transformer in self.transformer_blocks:
            x = transformer.forward(x, mask)
        return x
# 解码器结构根据具体任务而定
# 任务一般有三种:(1)序列分类,(2)标记分类,(3)序列生成
# 但一般都是全连接的
# 用于下个句子判断的解码器
# 序列分类任务,输入两个句子,输出一个标签,1表示是相邻句子,0表示不是
class NextSentencePrediction(nn.Module):
    """
    2-class classification model : is_next, is_not_next
    """
    def __init__(self, hidden):
        """
        :param hidden: BERT model output size
        """
        super().__init__()
    # 将向量压缩到两维, 尺寸为 ES * 2
        self.linear = nn.Linear(hidden, 2)
        self.softmax = nn.LogSoftmax(dim=-1)
    def forward(self, x):
    # 输入 -> 取第一个向量 -> LL -> softmax -> 输出
    # 输出相邻句子和非相邻句子的概率
        return self.softmax(self.linear(x[:, 0]))
# 用于完型填空的解码器
# 序列生成任务,输入是带有`<mask>`的句子,输出是完整句子
class MaskedLanguageModel(nn.Module):
    """
    predicting origin token from masked input sequence
    n-class classification problem, n-class = vocab_size
    """
    def __init__(self, hidden, vocab_size):
        """
        :param hidden: output size of BERT model
        :param vocab_size: total vocab size
        """
        super().__init__()
    # 将输入压缩到词汇表大小
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)
    def forward(self, x):
    # 输入 -> LL -> softmax -> 输出
    # 输出序列中每个词是词汇表中每个词的概率
        return self.softmax(self.linear(x))
相关文章
|
PyTorch 算法框架/工具
Bert Pytorch 源码分析:五、模型架构简图 REV1
Bert Pytorch 源码分析:五、模型架构简图 REV1
356 0
|
PyTorch 算法框架/工具
Bert Pytorch 源码分析:三、Transformer块
Bert Pytorch 源码分析:三、Transformer块
208 0
|
PyTorch 算法框架/工具 C++
Bert Pytorch 源码分析:二、注意力层
Bert Pytorch 源码分析:二、注意力层
317 0
|
PyTorch 算法框架/工具
Bert Pytorch 源码分析:五、模型架构简图
Bert Pytorch 源码分析:五、模型架构简图
278 0
|
11月前
|
PyTorch 调度 算法框架/工具
阿里云PAI-DLC任务Pytorch launch_agent Socket Timeout问题源码分析
DLC任务Pytorch launch_agent Socket Timeout问题源码分析与解决方案
548 18
阿里云PAI-DLC任务Pytorch launch_agent Socket Timeout问题源码分析
|
机器学习/深度学习 JavaScript PyTorch
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
生成对抗网络(GAN)的训练效果高度依赖于损失函数的选择。本文介绍了经典GAN损失函数理论,并用PyTorch实现多种变体,包括原始GAN、LS-GAN、WGAN及WGAN-GP等。通过分析其原理与优劣,如LS-GAN提升训练稳定性、WGAN-GP改善图像质量,展示了不同场景下损失函数的设计思路。代码实现覆盖生成器与判别器的核心逻辑,为实际应用提供了重要参考。未来可探索组合优化与自适应设计以提升性能。
1064 7
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
|
7月前
|
机器学习/深度学习 存储 PyTorch
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
Neural ODE将神经网络与微分方程结合,用连续思维建模数据演化,突破传统离散层的限制,实现自适应深度与高效连续学习。
615 3
Neural ODE原理与PyTorch实现:深度学习模型的自适应深度调节
|
6月前
|
边缘计算 人工智能 PyTorch
130_知识蒸馏技术:温度参数与损失函数设计 - 教师-学生模型的优化策略与PyTorch实现
随着大型语言模型(LLM)的规模不断增长,部署这些模型面临着巨大的计算和资源挑战。以DeepSeek-R1为例,其671B参数的规模即使经过INT4量化后,仍需要至少6张高端GPU才能运行,这对于大多数中小型企业和研究机构来说成本过高。知识蒸馏作为一种有效的模型压缩技术,通过将大型教师模型的知识迁移到小型学生模型中,在显著降低模型复杂度的同时保留核心性能,成为解决这一问题的关键技术之一。
540 6
|
8月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.0性能优化实战:4种常见代码错误严重拖慢模型
我们将深入探讨图中断(graph breaks)和多图问题对性能的负面影响,并分析PyTorch模型开发中应当避免的常见错误模式。
466 9

热门文章

最新文章

推荐镜像

更多