Bert Pytorch 源码分析:四、编解码器

简介: Bert Pytorch 源码分析:四、编解码器
# Bert 编码器模块
# 由一个嵌入层和 NL 个 TF 层组成
class BERT(nn.Module):
    """
    BERT model : Bidirectional Encoder Representations from Transformers.
    """
    def __init__(self, vocab_size, hidden=768, n_layers=12, attn_heads=12, dropout=0.1):
        """
        :param vocab_size: vocab_size of total words
        :param hidden: BERT model hidden size
        :param n_layers: numbers of Transformer blocks(layers)
        :param attn_heads: number of attention heads
        :param dropout: dropout rate
        """
        super().__init__()
    # 嵌入大小 ES
        self.hidden = hidden
    # TF 层数 NL
        self.n_layers = n_layers
    # 头部数量 HC
        self.attn_heads = attn_heads
        # FFN 层中的隐藏单元数量,记为 FF,一般是 ES 的四倍
        self.feed_forward_hidden = hidden * 4
        # 嵌入层,嵌入矩阵尺寸 VS * ES
        self.embedding = BERTEmbedding(vocab_size=vocab_size, embed_size=hidden)
        # NL 个 TF 层
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden, attn_heads, hidden * 4, dropout) for _ in range(n_layers)])
    def forward(self, x, segment_info):
        # 为`<pad>`(ID = 0)设置掩码
    # 尺寸为 BS * 1 * ML * ML,以便与相似性矩阵 S 匹配
    # 在每个 BS 的 ML * ML 矩阵中,`<pad>`标记对应的行为 1,其余为零
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        # 单词 ID 传入嵌入层得到词向量
        x = self.embedding(x, segment_info)
        # 依次传入每个 TF 层,得到编码器输出
        for transformer in self.transformer_blocks:
            x = transformer.forward(x, mask)
        return x
# 解码器结构根据具体任务而定
# 任务一般有三种:(1)序列分类,(2)标记分类,(3)序列生成
# 但一般都是全连接的
# 用于下个句子判断的解码器
# 序列分类任务,输入两个句子,输出一个标签,1表示是相邻句子,0表示不是
class NextSentencePrediction(nn.Module):
    """
    2-class classification model : is_next, is_not_next
    """
    def __init__(self, hidden):
        """
        :param hidden: BERT model output size
        """
        super().__init__()
    # 将向量压缩到两维, 尺寸为 ES * 2
        self.linear = nn.Linear(hidden, 2)
        self.softmax = nn.LogSoftmax(dim=-1)
    def forward(self, x):
    # 输入 -> 取第一个向量 -> LL -> softmax -> 输出
    # 输出相邻句子和非相邻句子的概率
        return self.softmax(self.linear(x[:, 0]))
# 用于完型填空的解码器
# 序列生成任务,输入是带有`<mask>`的句子,输出是完整句子
class MaskedLanguageModel(nn.Module):
    """
    predicting origin token from masked input sequence
    n-class classification problem, n-class = vocab_size
    """
    def __init__(self, hidden, vocab_size):
        """
        :param hidden: output size of BERT model
        :param vocab_size: total vocab size
        """
        super().__init__()
    # 将输入压缩到词汇表大小
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)
    def forward(self, x):
    # 输入 -> LL -> softmax -> 输出
    # 输出序列中每个词是词汇表中每个词的概率
        return self.softmax(self.linear(x))
相关文章
|
7月前
|
PyTorch 算法框架/工具
Bert Pytorch 源码分析:五、模型架构简图 REV1
Bert Pytorch 源码分析:五、模型架构简图 REV1
102 0
|
7月前
|
PyTorch 算法框架/工具
Bert Pytorch 源码分析:三、Transformer块
Bert Pytorch 源码分析:三、Transformer块
85 0
|
7月前
|
PyTorch 算法框架/工具 C++
Bert Pytorch 源码分析:二、注意力层
Bert Pytorch 源码分析:二、注意力层
112 0
|
7月前
|
PyTorch 算法框架/工具
Bert Pytorch 源码分析:五、模型架构简图
Bert Pytorch 源码分析:五、模型架构简图
77 0
|
2月前
|
算法 PyTorch 算法框架/工具
Pytorch学习笔记(九):Pytorch模型的FLOPs、模型参数量等信息输出(torchstat、thop、ptflops、torchsummary)
本文介绍了如何使用torchstat、thop、ptflops和torchsummary等工具来计算Pytorch模型的FLOPs、模型参数量等信息。
364 2
|
20天前
|
机器学习/深度学习 人工智能 PyTorch
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
本文探讨了Transformer模型中变长输入序列的优化策略,旨在解决深度学习中常见的计算效率问题。文章首先介绍了批处理变长输入的技术挑战,特别是填充方法导致的资源浪费。随后,提出了多种优化技术,包括动态填充、PyTorch NestedTensors、FlashAttention2和XFormers的memory_efficient_attention。这些技术通过减少冗余计算、优化内存管理和改进计算模式,显著提升了模型的性能。实验结果显示,使用FlashAttention2和无填充策略的组合可以将步骤时间减少至323毫秒,相比未优化版本提升了约2.5倍。
35 3
Transformer模型变长序列优化:解析PyTorch上的FlashAttention2与xFormers
|
2月前
|
机器学习/深度学习 自然语言处理 监控
利用 PyTorch Lightning 搭建一个文本分类模型
利用 PyTorch Lightning 搭建一个文本分类模型
69 8
利用 PyTorch Lightning 搭建一个文本分类模型
|
2月前
|
机器学习/深度学习 自然语言处理 数据建模
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
本文深入探讨了Transformer模型中的三种关键注意力机制:自注意力、交叉注意力和因果自注意力,这些机制是GPT-4、Llama等大型语言模型的核心。文章不仅讲解了理论概念,还通过Python和PyTorch从零开始实现这些机制,帮助读者深入理解其内部工作原理。自注意力机制通过整合上下文信息增强了输入嵌入,多头注意力则通过多个并行的注意力头捕捉不同类型的依赖关系。交叉注意力则允许模型在两个不同输入序列间传递信息,适用于机器翻译和图像描述等任务。因果自注意力确保模型在生成文本时仅考虑先前的上下文,适用于解码器风格的模型。通过本文的详细解析和代码实现,读者可以全面掌握这些机制的应用潜力。
121 3
三种Transformer模型中的注意力机制介绍及Pytorch实现:从自注意力到因果自注意力
|
3月前
|
机器学习/深度学习 PyTorch 调度
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
在深度学习中,学习率作为关键超参数对模型收敛速度和性能至关重要。传统方法采用统一学习率,但研究表明为不同层设置差异化学习率能显著提升性能。本文探讨了这一策略的理论基础及PyTorch实现方法,包括模型定义、参数分组、优化器配置及训练流程。通过示例展示了如何为ResNet18设置不同层的学习率,并介绍了渐进式解冻和层适应学习率等高级技巧,帮助研究者更好地优化模型训练。
204 4
在Pytorch中为不同层设置不同学习率来提升性能,优化深度学习模型
|
3月前
|
机器学习/深度学习 监控 PyTorch
PyTorch 模型调试与故障排除指南
在深度学习领域,PyTorch 成为开发和训练神经网络的主要框架之一。本文为 PyTorch 开发者提供全面的调试指南,涵盖从基础概念到高级技术的内容。目标读者包括初学者、中级开发者和高级工程师。本文探讨常见问题及解决方案,帮助读者理解 PyTorch 的核心概念、掌握调试策略、识别性能瓶颈,并通过实际案例获得实践经验。无论是在构建简单神经网络还是复杂模型,本文都将提供宝贵的洞察和实用技巧,帮助开发者更高效地开发和优化 PyTorch 模型。
54 3
PyTorch 模型调试与故障排除指南