与主流的feature蒸馏方法不同,本研究将重心放回到logits蒸馏上,提出了一种新的方法「解耦知识蒸馏」,重新达到了SOTA结果,为保证复现该研究还提供了开源的蒸馏代码库:MDistiller。
1 研究摘要
近年来顶会的 SOTA 蒸馏方法多基于 CNN 的中间层特征,而基于输出 logits 的方法被严重忽视了。饮水思源,本文中来自旷视科技 (Megvii)、早稻田大学、清华大学的研究者将研究重心放回到 logits 蒸馏上,对 7 年前 Hinton 提出的知识蒸馏方法(Knowledge Distillation,下文简称 KD)[1] 进行了解耦和分析,发现了一些限制 KD 性能的重要因素,进而提出了一种新的方法「解耦知识蒸馏」(Decoupled Knowledge Distillation,下文简称 DKD)[2],使得 logits 蒸馏重回 SOTA 行列。
同时,为了保证复现和支持进一步研究,该研究提供了一个全新的开源代码库 MDistiller,该库涵盖了 DKD 和大部分的 SOTA 方法,并不断进行更新维护,欢迎大家试用并提供宝贵的反馈意见。
2 研究动机
上图是大家熟知的 KD 方法,KD 用 Teacher 网络和 Student 网络的输出 logits 来计算 KL Loss,从而实现 dark knowledge 的传递,利用 Teacher 已经学到的知识帮助 Student 收敛得更好。在 KD 之后,更多的基于中间特征的蒸馏方法不断涌现,不断刷新知识蒸馏的 SOTA。但该研究认为,KD 这样的 logits 蒸馏方法具备两点好处:
1. 基于 feature 的蒸馏方法需要更多复杂的结构来拉齐特征的尺度和网络的表示能力,而 logits 蒸馏方法更简单高效;2. 相比中间 feature,logits 的语义信息是更 high-level 且更明确的,基于 logits 信号的蒸馏方法也应该具备更高的性能上限,因此,对 logits 蒸馏进行更多的探索是有意义的。
该研究尝试一种拆解的方法来更深入地分析 KD:将 logits 分成两个部分(如图),蓝色部分代表目标类别(target class)的 score,绿色部分代表非目标类别(Non-target class)的 score。这样的拆解使得我们可以重新推导 KD 的 Loss 公式,得到一个新的等价表达式,进而做更多的实验和分析。
2.1 符号定义
这里只写出关键符号定义,更具体的定义请参考论文正文。
首先,该研究将第 i 类的分类概率表示为(其中表示网络输出的 logits):
为了拆解分类网络输出的 logits,该研究接下来定义了两种新的概率分布:
1. 目标类 vs 非目标类的二分类分布,该概率分布和分类监督信号高度耦合。该分布包含两个元素:目标类概率和全部非目标类概率,分别表示为:
2. 非目标类内部竞争的多分类分布,也就是在预测样本为非目标类的前提下每个类各自的概率(总和为 1)。这个概率分布和分类的监督信号是不相关的,换句话说,从这个概率分布中无法得知目标类上的预测置信度,其表达式为:
根据上述定义,可以得到一个显然的数学关系:。这些定义和数学关系将帮助我们得到 KD Loss 的一个新的表达形式。
2.2 重新推导 KD Loss
首先,KD 的 Loss 定义如下:
然后根据公式(1)和(2),我们可以将其改写为:
可以观察到,式中的第一项只牵涉到了目标类别 vs 非目标类别的二分类概率分布,第二项牵涉到了非目标类概率分布的 KL 散度和权重。该研究将第一项命名为目标类别知识蒸馏 Target Class Knowledge Distillation(下文简称 TCKD),将第二项中的 KL 散度命名为非目标类别知识蒸馏 Non-target Class Knowledge Distillation(下文简称 NCKD)。至此,该研究完成了对 KD Loss 的拆分,将其分成了两个可单独使用的部分,并可以分析其各自的作用:
3 启发式探索
首先,该研究对 TCKD 和 NCKD 做了消融实验,观察它们对蒸馏性能的影响;接着,他们分别探索 TCKD 和 NCKD 的作用;最后,研究者做了一些启发式的讨论。
3.1 单独使用 TCKD/NCKD 训练
如表 1 所示,我们可以观察到:
1. 同时使用 TCKD 和 NCKD(等同于 KD),有不错的性能提升;2. 单独使用 TCKD 进行蒸馏,会对蒸馏效果产生较大的损害(这一点在补充材料中有详细讨论,主要和蒸馏温度 T 相关);3. 单独使用 NCKD 进行蒸馏,和 KD 的效果是差不多的,甚至有时会更好;
基于这些观察可以推出两个初步结论:
1.TCKD 是没用的,甚至在单独使用时可能是有害的;2.NCKD 可能是 KD 生效的主要原因;
接下来该研究就这两个初步的结论进行了进一步的分析。
3.2 TCKD:传递样本难度相关的知识
TCKD 作用于目标类的二分类概率分布上,这个概率的物理含义是「网络对样本的置信度」。比如:如果一个样本被 Teacher 学会了,会产生类似[0.99, 0.01] 的 binary 概率,而如果一个样本比较难拟合,则会产生类似 [0.6, 0.4] 的 binary 概率。所以该研究猜测:TCKD 传递了和样本拟合难度相关的知识,当训练集拟合难度高时才会起到作用。为了证明这一点,该研究设计了三组实验来增加 CIFAR-100 的训练难度,观察 TCKD 是否有效:
更强的数据增广:
以表 2 中的 ShuffleNet-V1 为例,在使用 AutoAugment 的情况下,训练集难度有了明显提升,此时仅仅使用 NCKD 只能达到 73.8% 的 student 准确率,而同时使用 TCKD 和 NCKD 可以将 student 准确率提升至 75.3%。
更 Noisy 的标签:
表 3 中,该研究通过控制 noisy ratio 对数据集的标签引入不同程度噪声,ratio 越大表示噪声越大。可以看到,随着数据集的噪声变大,单独使用 NCKD 的效果变得越来越差,同时引入 TCKD 的增益也越来越大。说明在越难学的数据上,TCKD 的作用就会越明显。
更难的数据集:
ImageNet 是一个比 CIFAR-100 更困难的数据集,所以该研究在 ImageNet 上也进行了尝试。从表 4 可以看出,在 ImageNet 上只使用 NCKD 的效果也是没有同时使用 TCKD 和 NCKD 要好的。
总结
三组实验都反映出,当训练数据拟合难度变高时(无论是数据本身难度、还是噪声和增广带来的难度),TCKD 能提供更有效的知识,对蒸馏性能的提升也越高,这些实验在一定程度上说明了 TCKD 确实是在传递有关样本拟合难度的知识,印证了该研究的想法。
3.3 NCKD:被抑制的重要成分
表 1 中反映出的另一个有趣的现象是:只使用 NCKD 也能取得令人满意的蒸馏效果,甚至可能比 KD 更好。这样的现象反映出:非目标类别上的 logits 中蕴含的信息,才是最主要的 dark knowledge 成分。
然而当回顾 KD 的新表达式时,发现 NCKD 对应的 loss 是和权重耦合在一起的。换言之,如果 teacher 网络的预测越置信,NCKD 的 loss 权重就更低,其作用就会越小。而该研究认为,teacher 更置信的样本能够提供更有益的 dark knowledge,和 NCKD 耦合的权重会严重抑制高置信度样本的知识迁移,使得知识蒸馏的效率大幅降低。为了证明这一点,该研究做了如下实验:
1. 依据 teacher 模型的置信度,该研究对训练集上的样本做了排序,并将排序后的样本分成置信(置信度 top-50%)和非置信 (剩余) 两个批次;2. 训练时,对全部样本使用分类 Loss,并只对置信批次 / 非置信批次使用 NCKD Loss;
实验结果如表 5 所示,0-50% 表示置信批次,50-100% 表示非置信批次。第一行是在整个训练集上做 NCKD 的结果,第二行表示只对置信批次做 NCKD,第三行表示只对非置信批次做 NCKD。显然,置信批次上使用 NCKD 带来了更主要的涨点,说明置信度更高的样本对蒸馏的训练过程是更有益的,因此是不应该被抑制的。
3.4 启发
至此,该研究完成了对 KD Loss 的解耦,并且分析了两个部分各自的作用。所有结果都表明,TCKD 和 NCKD 都有自己的重要作用,然而,研究注意到了在原始的 KD Loss 中,TCKD 和 NCKD 是存在不合理的耦合的:
1. 一方面,NCKD 和耦合,会导致高置信度样本的蒸馏效果大打折扣;2. 另一方面,TCKD 和 NCKD 是耦合的。然而这两个部分传递的知识是不同的,这样的耦合导致了他们各自的重要性没有办法灵活调整。
4 Decoupled Knowledge Distillation
根据推导和启发式探索,该研究提出了一种新的 logits 蒸馏方法“解耦知识蒸馏(DKD)”,来解决上一章提出的两个问题,如上图所示。DKD 的 Loss 表达式如下:
和 KD Loss 相比,该研究将限制 NCKD 的权重替换为了,并给 TCKD 设置了一个权重。DKD 可以很好地解决刚才提到的两个问题:一方面,TCKD 和 NCKD 被解耦,它们各自的重要性可以独立调节;另一方面,对于蒸馏更重要的 NCKD 也不会再被 Teacher 产生的高置信度抑制,大大提高了蒸馏的灵活性和有效性。DKD 的伪代码如下: