NLP中的知识蒸馏

本文涉及的产品
智能开放搜索 OpenSearch行业算法版,1GB 20LCU 1个月
实时计算 Flink 版,5000CU*H 3个月
实时数仓Hologres,5000CU*H 100GB 3个月
简介: 模型参数量决定其捕获知识的能力,但非线性关系;相同参数量的模型因训练方法差异,如早停法,捕获知识不同。知识蒸馏通过老师模型指导学生模型,利用软标签传递更多信息,提高学生模型性能。温度调控softmax输出,平衡信息获取与噪声抑制,小模型适合较低温度以聚焦关键知识。
  1. 模型的参数量与其所能捕获到的知识量的关系
    一个模型的参数量基本决定了其所能捕获到的知识量。这样的想法是基本正确的,但是需要注意的是

  2. 模型的参数量和其所能捕获的知识量之间并非稳定的线性关系(下图中的1),而是接近边际收益逐渐减少的一种增长曲线(下图中的2和3)

  3. 完全相同的模型架构和模型参数量,使用完全相同的训练数据,能捕获的知识量并不一定完全相同,其还会受到训练的方法的影响(early stopping)。

合适的训练方法可以使得在模型参数总量比较小时,尽可能地获取到更多的知识(下图中的3与2曲线的对比)image.png

  1. 老师模型与学生模型
    知识蒸馏的过程分为2个阶段:

老师模型相对复杂,也可以由多个分别训练的模型集成而成。我们对老师模型不作任何关于模型架构、参数量、是否集成方面的限制,唯一的要求就是,对于输入X, 其都能输出Y,其中Y经过softmax的映射,输出值对应相应类别的概率值。

学生模型是参数量较小、模型结构相对简单的单模型。同样的,对于输入X,其都能输出Y,Y经过softmax映射后同样能输出对应相应类别的概率值。

在本文中,将问题限定在分类问题下,该类问题的共同点是模型最后会有一个softmax层,其输出值对应了相应类别的概率值。

  1. 什么是知识蒸馏
    首先,蒸馏的一个重要目的是让学生模型学习到老师模型的泛化能力。

一个很高效的蒸馏方法就是使用老师网络softmax层输出的类别概率来作为软标签,和学生网络的softmax输出做交叉熵。

传统训练方法是硬标签,正类是1,其他所有负类都是0。

知识蒸馏的训练过程过程是软标签,用老师模型的类别概率作为软标签。

  1. 为什么知识蒸馏有效
    softmax层的输出,除了正类别之外,负类别也带有大量的信息,比如某些负类别对应的概率大于其他负类别。而在传统的训练过程(硬标签)中,所有负类别都被统一对待,都是0。也就是说,知识蒸馏的训练方式使得学生网络获取的信息量大于传统的训练方法。

例子(手写体数字识别任务)

假设某个输入的”2”更加形似"3",softmax的输出值中"3"对应的类别概率为0.1,而其他负类别对应的值都很小,而另一个输入"2"更加形似"7","7"对应的概率为0.1。

这两个"2"对应的硬标签的值是相同的,但是它们的软标签却是不同的,由此我们可见软标签蕴含着比硬标签多的信息。并且软标签的熵较高时,软标签蕴含的知识就更丰富。image.png

  1. 知识蒸馏中的softmax
    原始的softmax函数

image.png
要是直接使用softmax的输出值作为软标签, 这又会带来一个问题: 当softmax后的概率分布的熵相对较小时,也就是负标签的概率都很接近0,那么它们对损失函数的影响都非常小。

加了温度之后的softmax函数image.png原来的softmax函数是T = 1的特例。T越高,softmax的输出概率分布越趋于平滑,其分布的熵越大,负标签携带的信息会被相对地放大,模型训练将更加关注负标签。

一般地,知识蒸馏中,会把温度设置为>1。

  1. 在知识蒸馏中如何选择合适的softmax温度
    温度的高低改变的是学生网络对负标签的关注程度。

  2. 温度较低时,负类别携带的信息会被相对减少,对负类别的关注较少,负类别的概率越低,关注越少。

  3. 温度较高时,负类别的概率值会相对增大,负类别携带的信息会被相对地放大,学生网络会更多关注到负标签。

实际上,负类别中包含一定的信息,尤其是那些概率值较高的负类别。但由于老师网络的负类别可能会有噪声,并且负类别的概率值越低,其信息就越不可靠。因此温度的选取比较看经验,本质上就是在下面两件事之中取舍

  1. 从负类别中获取信息 --> 温度要高一些

  2. 防止受负类别中噪声的影响 --> 温度要低一些

总的来说,温度的选择和学生网络的大小有关,学生网络参数量比较小的时候,相对比较低的温度就可以了,因为参数量小的模型不能捕获所有知识,所以可以适当忽略掉一些负标签的信息。

目录
相关文章
|
29天前
|
机器学习/深度学习 自然语言处理 并行计算
探索深度学习中的Transformer模型及其在自然语言处理中的应用
【10月更文挑战第6天】探索深度学习中的Transformer模型及其在自然语言处理中的应用
79 0
|
1月前
|
机器学习/深度学习 自然语言处理 异构计算
【NLP自然语言处理】初识深度学习模型Transformer
【NLP自然语言处理】初识深度学习模型Transformer
|
3月前
|
机器学习/深度学习 自然语言处理 数据挖掘
【NLP】深度学习的NLP文本分类常用模型
本文详细介绍了几种常用的深度学习文本分类模型,包括FastText、TextCNN、DPCNN、TextRCNN、TextBiLSTM+Attention、HAN和Bert,并提供了相关论文和不同框架下的实现源码链接。同时,还讨论了模型的优缺点、适用场景以及一些优化策略。
85 1
|
3月前
|
机器学习/深度学习 自然语言处理
自然语言处理 Paddle NLP - 预训练语言模型及应用
自然语言处理 Paddle NLP - 预训练语言模型及应用
27 0
|
5月前
|
自然语言处理 数据挖掘
【自然语言处理NLP】Bert中的特殊词元表示
【自然语言处理NLP】Bert中的特殊词元表示
68 3
|
5月前
|
自然语言处理 监控 物联网
自然语言处理(NLP)微调
自然语言处理(NLP)微调
60 0
|
6月前
|
存储 机器学习/深度学习 自然语言处理
Transformer 自然语言处理(二)
Transformer 自然语言处理(二)
236 0
Transformer 自然语言处理(二)
|
6月前
|
存储 机器学习/深度学习 自然语言处理
Transformer 自然语言处理(四)
Transformer 自然语言处理(四)
330 0
Transformer 自然语言处理(四)
|
6月前
|
存储 自然语言处理 PyTorch
Transformer 自然语言处理(三)
Transformer 自然语言处理(三)
160 0
Transformer 自然语言处理(三)
|
6月前
|
机器学习/深度学习 自然语言处理 PyTorch
Transformer 自然语言处理(一)
Transformer 自然语言处理(一)
230 0
Transformer 自然语言处理(一)