PyTorch深度学习实战 | Transformer模型初识

在线体验各类最新模型,更有模型 免费Token 额度领取!
立即体验
简介: 本文介绍了Transformer模型在机器翻译任务中的工作原理。主要内容包括:1)模型分为编码器(处理英文输入)和解码器(生成中文输出)两部分;2)推理时采用自回归模式,逐步生成翻译结果;3)训练时使用教师强制模式,始终以真实标签指导模型学习。文章详细阐述了编码器-解码器结构的工作流程,包括词嵌入、多头注意力机制等核心组件,以及如何通过残差连接和层标准化优化模型性能。最后解释了编码器和解码器三个关键计算步骤的张量维度变化过程。

 专栏介绍

       这个专栏下我将会写5篇文章帮助大家更好的理解Transformer

【1】Transformer模型初识

【2】词嵌入和位置编码

【3】多头注意力机制

【4】层归一化层和FeedForward

【5】手动计算和完整的代码实现

宏观理解Transformer

    详细很多人都听过或者用过这个模型,你真的掌握Transformer了嘛?让我们来看一下下面这

两个问题。

问题一解释下面这两个概念:

Transformer 模型的工作模式:

(1)训练模式,教师强制,Teacher Forcing 模式

(2)推理模式,自回归,Autoregressive 模式

问题二,在下图解码器的这个部分中,回答谁是Q,谁是K,谁是V

image.gif

基于 “英译汉” 机器翻译任务快速认识 Transformer

     现在我们来假想一个英译汉的机器翻译任务。先把 Transformer 模型看作一个整体:假设待翻

译的英文数据是 “Are you OK?”,对应的中文标注数据是 “你好吗?”,经过模型的推理

计算后,最终输出的翻译结果为 “你干什么?”。这个模型核心由两部分构成 —— 编码器与解码

器。

image.gif

Transformer 的推理模式即自回归,Autoregressive 模式

   模型已经完成训练后,我们就可以利用训练好的参数对未知数据进行推理计算。比如在这个翻译

任务中,我们会输入全新的英文句子,让模型直接输出对应的中文翻译结果 —— 而推理过程中,

我们并不知道这句英文的标准答案(也就是之前提到的、没有标注的中文数据)。那么,模型该如

何一步步生成这个 “未知的中文翻译结果” 呢?

image.gif

自回归推理模式本质是一个循环生成过程。在第一次循环中,待翻译的英文句子 “Are you OK?”

会从模型左下方输入,而模型右下方不会有完整的中文标注数据输入 —— 最初只会传入一个特殊

起始符号<start>,以此提示模型:中文翻译任务正式启动。

image.gif

    Transformer 的编码器负责接收并处理左侧的英文数据,解码器则接收并处理特殊起始符号

<start>。在二者的协同作用下,模型完成第一次推理:若模型计算准确,其右上角会输出翻译结

果的第一个字 “你”。由于输入的英文内容固定不变,左侧编码器仅需进行一次计算;到第二次推

理时,解码器会接收并处理起始符号<start>和已生成的 “你”,进而完成第二次推理,生成翻译结

果的第二个字 “好”

image.gif

由于左侧编码器只需完成一次计算,后续推理可直接复用其输出的编码结果,因此我们可以将编码

器省略,把它的编码结果视为一个固定向量传入解码器。接下来模型会继续按上述逻辑循环 n 轮:

每一轮解码器都会接收之前生成的所有中文 tokens 与起始符号<start>,逐步生成后续内容,直到

右上方输出特殊结束符号<end>,便意味着整个翻译过程完成。

image.gif

基于此前已生成的结果,持续推理后续的输出内容。模型逐步生成输出序列的过程中,自回归的推

理模型如图所示。

image.gif

Transformer 的训练模式即教师强制,Teacher Forcing 模式

    我们先把 Transformer 看作一个基础的深度学习模型。训练这类模型的核心逻辑是:计算模型

的输出结果与人工标注的标准答案之间的误差,再通过优化算法找到能让损失函数最小化的模型参

数从整体视角看 Transformer 的前向传播过程:输入是英文句子 “Are you OK?”预期输出是中

文标注 “你好啊”。但 Transformer 有个特殊设计 —— 人工标注的 “你好啊”,除了作为计算误差的

参考,还需要作为解码器的输入参与训练。

          之前 “模型将英文与中文标注一起推理出‘你干什么’” 的例子,其实没法体现 Transformer 真

实的训练过程。在教师强制(Teacher Forcing)的训练模式下,模型输出的字数必须和人工标注

的真实数据长度完全一致 —— 比如标注数据是 “你干嘛”(3 个字),模型就只能预测出 3 个字的

结果,而不会是长度不同的 “你干什么”。接下来计算损失函数、反向传播求解梯度,再基于梯度

更新模型参数,到这里就完成了一次训练迭代。

image.gif

教师强制模式的前向传播,核心是根据不同长度的真实中文标签,分步骤进行推理计算。

举个具体例子:假设待翻译的英文是 “Are you OK?”,对应的中文标注是 “你好吗?”(5 个字)。

我们会先给标注数据添加特殊起始符号<start>,再构造 5 组训练数据 —— 每组数据的时间步 t,

都与解码器输入的标签片段长度完全一致。

   比如在第二个时间步,若解码器输出的 y3 本应对应 “好” 字,却出现了预测错误,我们会基于这

个误差优化损失函数。关键在于,即便第二个时间步的预测结果有误,到了第三个时间步,解码器

也不会使用前一步的错误输出,而是依然强制接收真实的标签数据作为输入 —— 这种 “无论前序

预测是否正确,都以真实标签引导后续训练” 的方式,就是教师强制(Teacher Forcing)模式的核

心。

image.gif

整体而言,“Are you OK” 会被输入到编码器进行编码处理。而输入给解码器的数据,从结构上可看

作下三角形式 —— 这意味着在训练过程中,每个时间步的解码器输入仅包含此前的真实标签(不

包含当前及后续的标签信息)。模型的预测结果是一个列向量y,我们需要将其与人工标注的标签

向量y进行比对,进而计算所有时间步的损失并反向传播。值得注意的是,解码器的输入本质上仍

是一整行数据,只是会通过因果掩码矩阵在注意力计算阶段发挥作用,确保每个位置的预测仅依赖

于之前的信息,避免未来信息的 “泄露”。

image.gif

image.gif


Transformer的工作的架构设计和工作流程

整体包括:

编码器 - 解码器结构 (Encoder-Decoder)

词向量 (Embedding)和位置编码 (Positional Encoding)

多头注意力机制 (Multi-Head-Attention)

前馈神经网络 (Feed Forward)

残差连接 (Add) 和层标准化 (Norm)

线性层 (Linear) 和 softmax 层

image.gif

词向量 (Embedding)和位置编码 (Positional Encoding)

词的最终表示等于词向量表示加上位置编码

编码器和解码器结构

有多组编码器和解码器串联堆叠的工作,主要包含了3个计算

第一个计算

     编码器基于:自注意力机制 (Multi-Head Attention)前馈神经网络 (Feed Forward)对 “英文的待

译数据” 进行编码。

待译英文的张量 x1 的尺寸为 [B, Q1, D]:

B 个样本每个样本包括 Q1 个单词每个单词的嵌入维度是 D

经过解码器之后(可以看出提取特征),输出的尺寸是[B, Q1, D]

第二个计算

     解码器基于:带有掩码的 (Masked)自注意力机制 (Multi-Head Attention)对 “中文的标注数据”

进行编码。(也可以看成提取特征)

中文标注的张量 x2 的尺寸为 [B, Q2, D]:

B 个样本每个样本包括 Q2 个单词每个单词的嵌入维度是 D

(对于 x1 和 x2:

1. 包含的样本个数,因此具有相同的批量大小 B

2. 英文和中文的表达方式不同,所以句子的长度不同

3. 嵌入维度 D 在通常情况下相同,保证英文和中文在特征的维度上对齐,使模型的设计更加简洁)输出的尺寸是[B, Q2, D]

第三个计算

     解码器的第 2 个注意力层:编码器 - 解码器注意力层(Multi-Head Attention),对两组数据一起

解码得到解码器的最终输出。通过编码器和解码器中的 “自注意力机制”:将输入数据中的全局信息

计算并附加到解码结果中,输出的尺寸是[B, Q2, D]

x3 的单词数量会与 x2 相同;因为将英文 x1 翻译为中文 x2,所以翻译结果 x3 的长度,和中文标

注 x2 的长度相同。

目录
相关文章
|
24天前
|
机器学习/深度学习 数据采集 人工智能
田间杂草检测数据集分享(适用于YOLO系列深度学习分类检测任务)
本数据集含4000张真实农田图像(小麦/玉米/水稻田),YOLO格式标注杂草目标,覆盖多天气、光照与视角,适用于YOLO系列等目标检测模型训练,助力智能除草与精准农业研究。(239字)
346 16
|
24天前
|
数据采集 存储 算法
视频 RAG 中分块策略:基于停顿、滑动窗口与基于 LLM 的方法
本文探讨视频RAG中的核心挑战——如何为无时间结构的视频转录文本设计有效分块策略。对比传统文本分块,提出基于停顿、重叠窗口、递归切分及LLM驱动的主题分块四层方案,实现细粒度检索与全局理解兼顾,提升视频内容检索准确性与上下文完整性。
177 13
视频 RAG 中分块策略:基于停顿、滑动窗口与基于 LLM 的方法
|
24天前
|
安全 JavaScript 前端开发
《ZAKU渗透论:卓伊凡的2026渗透工程》第四章:Web攻击原理(下)——XSS、CSRF、文件上传漏洞
本章详解XSS、CSRF与文件上传三大Web漏洞:XSS通过注入恶意脚本窃取Cookie;CSRF伪造已登录用户请求执行非自愿操作;文件上传漏洞则因校验缺失致服务器被控。三者共性——过度信任用户输入。(239字)
335 10
|
24天前
|
SQL 安全 程序员
《ZAKU渗透论:卓伊凡的2026渗透工程》第三章:Web攻击原理(上)——注入与SQL注入
本章详解Web攻击核心——注入与SQL注入。通过“小明输入‘小明’OR‘1’=‘1’秒变管理员”的生动案例,揭示攻击本质:程序混淆数据与代码,导致恶意SQL被执行。深入剖析万能密码、数据窃取、权限绕过等危害,并指出漏洞长期存在的根源:历史代码、意识不足与修复成本。
277 2
|
24天前
|
安全 NoSQL Java
《ZAKU渗透论:卓伊凡的2026渗透工程》信息收集——黑客怎么找到你?
本章详解渗透测试中至关重要的信息收集环节:占全程50%以上工作量。涵盖被动(搜索引擎、GitHub、社交媒体、Whois、历史快照)与主动(DNS查询、子域名枚举、端口扫描、目录探测)两大策略,并聚焦2026年新趋势——供应链踩点。目标是绘制精准“攻击地图”,找到阻力最小的突破口。(239字)
247 2
|
24天前
|
人工智能 API 语音技术
阿里云百炼CLI是什么?如何安装使用百炼CLI命令行工具?
阿里云百炼CLI是百炼AI大模型平台的命令行工具,支持全模态对话、图像/视频生成与编辑、语音合成识别、联网搜索、知识库检索等10+能力,深度集成AI Agent与Skills技能扩展,助力高效开发电商图、播客等内容。在阿里云百炼平台快速体验:https://t.aliyun.com/U/fPVHqY
309 0
|
24天前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch深度学习实战 |层归一化层和FeedForward
本文介绍了PyTorch深度学习中Add&amp;Norm层和FeedForward层的实现原理。Add&amp;Norm层由残差连接(Add)和层归一化(Norm)组成,能加速模型收敛并稳定训练。层归一化会对神经网络每层的输出进行归一化处理,文中详细展示了其计算方法和PyTorch实现代码。FeedForward层是一个两层的全连接网络,通过线性变换提取更深层次特征。文章还分析了Transformer模型中使用层归一化的原因,并提供了完整的代码实现,包括参数初始化和前向传播过程。
126 0
|
24天前
|
机器学习/深度学习 编解码 算法
PyTorch深度学习实战 |手算​​U-net
本文详细解析了U-Net网络架构及其在医学图像分割中的应用。重点对比了U-Net与FCN的核心区别:U-Net采用特征拼接(Concat)保留所有层级信息,而FCN使用特征相加(Add)进行融合。文章深入剖析了U-Net的编码器-瓶颈-解码器结构,解释了其独特的裁剪拼接机制和Overlap-tile策略,并提供了完整的PyTorch实现代码。现代U-Net通过SamePadding实现了输入输出尺寸一致,显著提升了分割精度。文章还探讨了弹性形变数据增强和带空间权重的损失函数设计,为医学图像分析提供了实用解决
162 2
|
24天前
|
机器学习/深度学习 PyTorch 测试技术
PyTorch深度学习实战 |多头注意力机制
摘要:本文详细介绍了Transformer中的多头自注意力机制,从整体结构到实现细节,包括四部分内容:(1)多头自注意力的基本架构;(2)内部计算流程解析;(3)注意力计算公式详解;(4)代码实现。重点阐述了多头并行的计算方式、缩放点积注意力的计算步骤(QK转置、缩放、softmax和加权求和),以及残差连接和层归一化的作用。通过&quot;Are you OK?&quot;示例展示了输入张量如何经过8个64维注意力头处理后拼接成512维输出。文章最后提供了完整的PyTorch实现代码,并附测试用例验证模型
147 0
|
24天前
|
机器学习/深度学习 数据挖掘 PyTorch
PyTorch深度学习实战 |手算​​FCN全卷积神经网络
本文介绍了FCN-8s语义分割网络的实现细节。首先解释了语义分割的概念及其与图像分类的区别,重点分析了FCN网络结构中的全卷积化、上采样和跳跃连接三个关键技术。全卷积化将传统CNN的全连接层改为卷积层,实现像素级分类;上采样通过双线性插值恢复特征图尺寸;跳跃连接则融合高低层特征以提升细节表现。文章详细推导了损失函数的计算过程,并提供了完整的PyTorch实现代码,包括双线性插值权重初始化、VGG16骨干网络和FCN-8s主体结构。最后通过测试验证了模型能正确输出与输入尺寸匹配的预测结果。
223 3