CVPR 2022|解耦知识蒸馏,让Hinton在7年前提出的方法重回SOTA行列(1)

简介: CVPR 2022|解耦知识蒸馏,让Hinton在7年前提出的方法重回SOTA行列
与主流的feature蒸馏方法不同,本研究将重心放回到logits蒸馏上,提出了一种新的方法「解耦知识蒸馏」,重新达到了SOTA结果,为保证复现该研究还提供了开源的蒸馏代码库:MDistiller

1 研究摘要


近年来顶会的 SOTA 蒸馏方法多基于 CNN 的中间层特征,而基于输出 logits 的方法被严重忽视了。饮水思源,本文中来自旷视科技 (Megvii)、早稻田大学、清华大学的研究者将研究重心放回到 logits 蒸馏上,对 7 年前 Hinton 提出的知识蒸馏方法(Knowledge Distillation,下文简称 KD)[1] 进行了解耦和分析,发现了一些限制 KD 性能的重要因素,进而提出了一种新的方法「解耦知识蒸馏」(Decoupled Knowledge Distillation,下文简称 DKD)[2],使得 logits 蒸馏重回 SOTA 行列。

同时,为了保证复现和支持进一步研究,该研究提供了一个全新的开源代码库 MDistiller,该库涵盖了 DKD 和大部分的 SOTA 方法,并不断进行更新维护,欢迎大家试用并提供宝贵的反馈意见。




2 研究动机


上图是大家熟知的 KD 方法,KD 用 Teacher 网络和 Student 网络的输出 logits 来计算 KL Loss,从而实现 dark knowledge 的传递,利用 Teacher 已经学到的知识帮助 Student 收敛得更好。在 KD 之后,更多的基于中间特征的蒸馏方法不断涌现,不断刷新知识蒸馏的 SOTA。但该研究认为,KD 这样的 logits 蒸馏方法具备两点好处:

1. 基于 feature 的蒸馏方法需要更多复杂的结构来拉齐特征的尺度和网络的表示能力,而 logits 蒸馏方法更简单高效;2. 相比中间 feature,logits 的语义信息是更 high-level 且更明确的,基于 logits 信号的蒸馏方法也应该具备更高的性能上限,因此,对 logits 蒸馏进行更多的探索是有意义的。

该研究尝试一种拆解的方法来更深入地分析 KD:将 logits 分成两个部分(如图),蓝色部分代表目标类别(target class)的 score,绿色部分代表非目标类别(Non-target class)的 score。这样的拆解使得我们可以重新推导 KD 的 Loss 公式,得到一个新的等价表达式,进而做更多的实验和分析。

2.1 符号定义

这里只写出关键符号定义,更具体的定义请参考论文正文。

首先,该研究将第 i 类的分类概率表示为(其中表示网络输出的 logits):


为了拆解分类网络输出的 logits,该研究接下来定义了两种新的概率分布

1. 目标类 vs 非目标类的二分类分布,该概率分布和分类监督信号高度耦合。该分布包含两个元素:目标类概率和全部非目标类概率,分别表示为:


2. 非目标类内部竞争的多分类分布,也就是在预测样本为非目标类的前提下每个类各自的概率(总和为 1)。这个概率分布和分类的监督信号是不相关的,换句话说,从这个概率分布中无法得知目标类上的预测置信度,其表达式为:


根据上述定义,可以得到一个显然的数学关系:这些定义和数学关系将帮助我们得到 KD Loss 的一个新的表达形式。

2.2 重新推导 KD Loss

首先,KD 的 Loss 定义如下:


然后根据公式(1)和(2),我们可以将其改写为:


可以观察到,式中的第一项只牵涉到了目标类别 vs 非目标类别的二分类概率分布,第二项牵涉到了非目标类概率分布的 KL 散度和权重该研究将第一项命名为目标类别知识蒸馏 Target Class Knowledge Distillation(下文简称 TCKD),将第二项中的 KL 散度命名为非目标类别知识蒸馏 Non-target Class Knowledge Distillation(下文简称 NCKD)。至此,该研究完成了对 KD Loss 的拆分,将其分成了两个可单独使用的部分,并可以分析其各自的作用:


3 启发式探索

首先,该研究对 TCKD 和 NCKD 做了消融实验,观察它们对蒸馏性能的影响;接着,他们分别探索 TCKD 和 NCKD 的作用;最后,研究者做了一些启发式的讨论。

3.1 单独使用 TCKD/NCKD 训练


如表 1 所示,我们可以观察到:

1. 同时使用 TCKD 和 NCKD(等同于 KD),有不错的性能提升;2. 单独使用 TCKD 进行蒸馏,会对蒸馏效果产生较大的损害(这一点在补充材料中有详细讨论,主要和蒸馏温度 T 相关);3. 单独使用 NCKD 进行蒸馏,和 KD 的效果是差不多的,甚至有时会更好;

基于这些观察可以推出两个初步结论:

1.TCKD 是没用的,甚至在单独使用时可能是有害的;2.NCKD 可能是 KD 生效的主要原因;

接下来该研究就这两个初步的结论进行了进一步的分析。

3.2 TCKD:传递样本难度相关的知识

TCKD 作用于目标类的二分类概率分布上,这个概率的物理含义是「网络对样本的置信度」。比如:如果一个样本被 Teacher 学会了,会产生类似[0.99, 0.01] 的 binary 概率,而如果一个样本比较难拟合,则会产生类似 [0.6, 0.4] 的 binary 概率。所以该研究猜测:TCKD 传递了和样本拟合难度相关的知识,当训练集拟合难度高时才会起到作用。为了证明这一点,该研究设计了三组实验来增加 CIFAR-100 的训练难度,观察 TCKD 是否有效:

更强的数据增广:


以表 2 中的 ShuffleNet-V1 为例,在使用 AutoAugment 的情况下,训练集难度有了明显提升,此时仅仅使用 NCKD 只能达到 73.8% 的 student 准确率,而同时使用 TCKD 和 NCKD 可以将 student 准确率提升至 75.3%。

更 Noisy 的标签:


表 3 中,该研究通过控制 noisy ratio 对数据集的标签引入不同程度噪声,ratio 越大表示噪声越大。可以看到,随着数据集的噪声变大,单独使用 NCKD 的效果变得越来越差,同时引入 TCKD 的增益也越来越大。说明在越难学的数据上,TCKD 的作用就会越明显。

更难的数据集:


ImageNet 是一个比 CIFAR-100 更困难的数据集,所以该研究在 ImageNet 上也进行了尝试。从表 4 可以看出,在 ImageNet 上只使用 NCKD 的效果也是没有同时使用 TCKD 和 NCKD 要好的。

总结

三组实验都反映出,当训练数据拟合难度变高时(无论是数据本身难度、还是噪声和增广带来的难度),TCKD 能提供更有效的知识,对蒸馏性能的提升也越高,这些实验在一定程度上说明了 TCKD 确实是在传递有关样本拟合难度的知识,印证了该研究的想法。

3.3 NCKD:被抑制的重要成分

表 1 中反映出的另一个有趣的现象是:只使用 NCKD 也能取得令人满意的蒸馏效果,甚至可能比 KD 更好。这样的现象反映出:非目标类别上的 logits 中蕴含的信息,才是最主要的 dark knowledge 成分。

然而当回顾 KD 的新表达式时,发现 NCKD 对应的 loss 是和权重耦合在一起的。换言之,如果 teacher 网络的预测越置信,NCKD 的 loss 权重就更低,其作用就会越小。而该研究认为,teacher 更置信的样本能够提供更有益的 dark knowledge,和 NCKD 耦合的权重会严重抑制高置信度样本的知识迁移,使得知识蒸馏的效率大幅降低。为了证明这一点,该研究做了如下实验:

1. 依据 teacher 模型的置信度,该研究对训练集上的样本做了排序,并将排序后的样本分成置信(置信度 top-50%)和非置信 (剩余) 两个批次;2. 训练时,对全部样本使用分类 Loss,并只对置信批次 / 非置信批次使用 NCKD Loss;


实验结果如表 5 所示,0-50% 表示置信批次,50-100% 表示非置信批次。第一行是在整个训练集上做 NCKD 的结果,第二行表示只对置信批次做 NCKD,第三行表示只对非置信批次做 NCKD。显然,置信批次上使用 NCKD 带来了更主要的涨点,说明置信度更高的样本对蒸馏的训练过程是更有益的,因此是不应该被抑制的。

3.4 启发

至此,该研究完成了对 KD Loss 的解耦,并且分析了两个部分各自的作用。所有结果都表明,TCKD 和 NCKD 都有自己的重要作用,然而,研究注意到了在原始的 KD Loss 中,TCKD 和 NCKD 是存在不合理的耦合的:

1. 一方面,NCKD 和耦合,会导致高置信度样本的蒸馏效果大打折扣;2. 另一方面,TCKD 和 NCKD 是耦合的。然而这两个部分传递的知识是不同的,这样的耦合导致了他们各自的重要性没有办法灵活调整。

4 Decoupled Knowledge Distillation


根据推导和启发式探索,该研究提出了一种新的 logits 蒸馏方法“解耦知识蒸馏(DKD)”,来解决上一章提出的两个问题,如上图所示。DKD 的 Loss 表达式如下:


和 KD Loss 相比,该研究将限制 NCKD 的权重替换为了,并给 TCKD 设置了一个权重DKD 可以很好地解决刚才提到的两个问题:一方面,TCKD 和 NCKD 被解耦,它们各自的重要性可以独立调节;另一方面,对于蒸馏更重要的 NCKD 也不会再被 Teacher 产生的高置信度抑制,大大提高了蒸馏的灵活性和有效性。DKD 的伪代码如下:



相关文章
|
存储 JSON 网络协议
微服务Consul集群搭建
Consul是HashiCorp的开源工具,用于服务发现、配置管理和分布式一致性。它提供服务注册与发现、健康检查、KV存储、多数据中心支持,并基于Raft协议保证一致性。Consul还具有DNS接口和Web UI。要安装,可从HashiCorp或阿里云下载,使用`yum`在Linux上安装。启动单机模式用`consul agent -dev`,集群部署涉及配置文件如`/etc/consul.d/consul.hcl`。常用命令包括启动、加入集群、查看成员及服务管理等。
微服务Consul集群搭建
查看 npm 包下载量(简单快捷,数据精确)
查看 npm 包下载量(简单快捷,数据精确)
1402 0
|
自然语言处理 算法 数据挖掘
自蒸馏:一种简单高效的优化方式
背景知识蒸馏(knowledge distillation)指的是将预训练好的教师模型的知识通过蒸馏的方式迁移至学生模型,一般来说,教师模型会比学生模型网络容量更大,模型结构更复杂。对于学生而言,主要增益信息来自于更强的模型产出的带有更多可信信息的soft_label。例如下右图中,两个“2”对应的hard_label都是一样的,即0-9分类中,仅“2”类别对应概率为1.0,而soft_label
自蒸馏:一种简单高效的优化方式
|
8月前
|
机器学习/深度学习 人工智能 弹性计算
数字产科平台构建方案
数字产科平台构建方案:基于云计算与国产化技术,集成AI、物联网与RPA,实现孕产妇全周期智能管理。涵盖自助建档、高危预警、远程监护、智能宣教等功能。
278 3
|
7月前
|
监控 Java BI
《深入理解Spring》定时任务——自动化调度的时间管理者
Spring定时任务通过@Scheduled注解和Cron表达式实现灵活调度,支持固定频率、延迟执行及动态配置,结合线程池与异常处理可提升可靠性,适用于报表生成、健康检查等场景,助力企业级应用自动化。
|
机器学习/深度学习 人工智能 数据处理
《从“平”到“立”,3D集成技术如何重塑AI芯片能效版图》
3D集成技术正革新人工智能芯片的性能与能效。传统2D芯片设计受限于平面空间,信号传输延迟、能耗高;而3D集成通过垂直堆叠芯片层,大幅缩短信号路径,提升数据处理速度和计算密度,同时降低能耗并优化电源管理。它在数据中心和边缘设备中展现出巨大潜力,助力图像识别、语音处理等任务高效完成。尽管面临散热与成本挑战,但随着技术进步,3D集成有望成为AI芯片主流,推动人工智能更广泛的应用与创新。
389 0
|
10月前
|
人工智能 资源调度 算法
2025魔搭开发者大会 · 全景回顾
6月30日,2025魔搭开发者大会(ModelScope DevCon 2025)在北京海淀 · 香格里拉饭店圆满收官!
1049 0
|
机器学习/深度学习 人工智能 自然语言处理
【AI系统】知识蒸馏原理
本文深入解析知识蒸馏(Knowledge Distillation, KD),一种将大型教师模型的知识高效转移至小型学生模型的技术,旨在减少模型复杂度和计算开销,同时保持高性能。文章涵盖知识蒸馏的基本原理、不同类型的知识(如响应、特征、关系知识)、蒸馏方式(离线、在线、自蒸馏)及Hinton的经典算法,为读者提供全面的理解。
2619 2
【AI系统】知识蒸馏原理
|
JavaScript 前端开发
TS基础语法
TypeScript(缩写为TS)是一种静态类型的JavaScript超集,它为JavaScript添加了类型注解和其他扩展功能。下面是TypeScript的基础语法
|
机器学习/深度学习 并行计算 PyTorch
从零开始下载torch+cu(无痛版)
这篇文章提供了一个详细的无痛版教程,指导如何从零开始下载并配置支持CUDA的PyTorch GPU版本,包括查看Cuda版本、在官网检索下载包名、下载指定的torch、torchvision、torchaudio库,并在深度学习环境中安装和测试是否成功。
从零开始下载torch+cu(无痛版)