模型落地必备 | 南开大学提出CrossKD蒸馏方法,同时兼顾特征和预测级别的信息

简介: 模型落地必备 | 南开大学提出CrossKD蒸馏方法,同时兼顾特征和预测级别的信息

这篇论文提出了一种名为CrossKD的简单而有效的知识蒸馏方案,用于对目标检测模型进行模型压缩。现有的目标检测领域的最先进知识蒸馏方法主要基于特征蒸馏,一般被认为比预测蒸馏更好。然而,作者发现预测蒸馏的低效性主要是由于GT信号和蒸馏目标之间的优化目标不一致造成的。

为了解决这个问题,作者提出了CrossKD这种简单而有效的蒸馏方案,该方案将学生模型的检测Head的中间特征传递给教师模型的检测Head,并强制使得交叉Head的预测结果与教师模型的预测一致。通过这种蒸馏方式,学生模型的Head能够避免接收来自GT注释和教师模型预测的相互矛盾的监督信号,从而大大提高了学生模型的检测性能。

在MS COCO数据集上,仅应用预测蒸馏损失,作者的CrossKD将GFL ResNet-50模型的平均精度从40.2提升到43.7,超过了所有现有的目标检测领域的知识蒸馏方法。

1、简介

知识蒸馏(KD)作为一种模型压缩技术,在目标检测领域进行了深入研究,并且近年来取得了卓越的性能。根据蒸馏位置的不同,现有的知识蒸馏方法可以大致分为两类:

  • 预测蒸馏
  • 特征蒸馏

预测蒸馏(见图1(a))最早在[Distilling the knowledge in a neural network]中提出,指出教师的预测结果的平滑分布对学生的学习更有利,而不是GT值的Dirac分布。换句话说,预测蒸馏迫使学生的预测结果与教师的预测分布相似。

而特征蒸馏(见图1(b))则遵循了FitNet中提出的思想,认为中间特征包含的信息比教师的预测结果更多。它旨在强迫教师和学生之间的特征保持一致性。

预测蒸馏在蒸馏目标检测模型中起着至关重要的作用。然而,长期以来观察到预测蒸馏比特征蒸馏更低效。通常情况下,郑等人提出了一种名为局部化蒸馏(LD)的方法,通过传递定位知识来改进预测蒸馏,将预测蒸馏推向了一个新的水平。尽管与先进的特征蒸馏方法(例如PKD)相比仍有差距,但LD表明预测蒸馏具有传递任务特定知识的能力,这使得学生模型能够从与特征蒸馏不同的角度受益。这激发了作者进一步探索和改进预测蒸馏的动力。

通过观察,作者发现预测蒸馏需要应对GT目标和蒸馏目标之间的冲突,而这在先前的工作中被忽视了。当使用预测蒸馏来训练一个检测器时,学生模型的预测结果被同时迫使蒸馏GT目标和教师模型的预测结果。然而,教师模型预测的蒸馏目标通常与分配给学生模型的GT目标存在很大的差异。

如图2所示,教师模型在绿色圈出的区域中产生了不准确的类别概率,这与GT目标产生了冲突。因此,在蒸馏过程中,学生模型经历了一种矛盾的学习过程,作者认为这是阻碍预测蒸馏实现更高性能的主要原因。

在本文中,作者提出了一种新颖的交叉Head知识蒸馏方法,称为CrossKD,以缓解目标冲突问题。如图1(c)所示,作者建议将学生模型Head的中间特征输入到教师模型的Head,得到交叉Head预测。然后,可以在新的交叉Head预测和教师模型的原始预测之间进行知识蒸馏操作。

尽管CrossKD非常简单,但具有以下两个主要优势:

  • 首先,知识蒸馏损失不会影响学生模型Head的权重更新,避免了原始检测损失和知识蒸馏损失之间的冲突。
  • 此外,由于交叉Head预测和教师模型的预测都是通过共享部分教师模型的检测Head生成的,因此交叉Head预测与教师模型的预测相对一致。这减轻了教师-学生对之间的差异,增强了预测蒸馏的训练稳定性。

这两个优势使得作者的CrossKD能够高效地从教师模型的预测中提取知识,并且比先前最先进的特征蒸馏方法具有更好的性能。

作者的方法没有任何花里胡哨的东西,但可以显著提升学生模型检测器的性能,实现更快的训练收敛。本文在COCO数据集上进行了全面的实验,详细说明了CrossKD的有效性。

具体而言,仅应用预测蒸馏损失的情况下,CrossKD在1×训练计划下在GFL上达到了43.7 AP,比Baseline高出3.5 AP,超过了之前所有最先进的目标检测知识蒸馏方法。此外,实验证明作者的CrossKD与特征蒸馏方法是正交的。通过将CrossKD与最先进的特征蒸馏方法(例如PKD)相结合,作者在GFL上进一步实现了43.9 AP的性能。

2、本文方法

2.1 准备工作

在介绍作者的CrossKD之前,作者首先简要回顾两种主要类型的知识蒸馏方法:特征蒸馏和预测蒸馏。

1、特征蒸馏

特征蒸馏最早是在[FitNets]中提出的,旨在在教师和学生的潜在特征上强制实现一致性。其目标可以表示为以下公式:

在上述公式中,和分别表示学生和教师的中间特征,通常是指特征金字塔网络(Feature Pyramid Network,FPN)的特征。被引入来衡量和之间的距离,例如,在[FitNets]中使用均方误差(Mean Square Error),在[PKD]中使用皮尔逊相关系数(Pearson Correlation Coefficient)。

在上述描述中,表示区域选择原则,它在整个图像区域R中为每个位置生成一个权重。为了避免大幅度的噪声干扰模型的收敛性,不同的方法可能会使用不同的区域选择原则来选择用于蒸馏的有效区域,并平衡前景和背景样本的权重。最后,损失将通过整个中间特征上的的累积进行归一化处理。

特征蒸馏由于其出色的性能已经成为目标检测知识蒸馏方法的主流。然而,这可能强迫学生模型蒸馏教师模型中的不必要的噪声,这可能对最终结果产生负面影响。

2、预测模拟

与特征蒸馏不同,预测蒸馏旨在通过最小化教师和学生之间的预测差异来传递深层知识。其目标可以描述为:

其中,和分别是学生和教师检测Head生成的预测向量。区域选择原则根据不同的工作而有所不同。是计算和之间差异的损失函数,例如,用于分类的KL散度,用于回归的L1损失和LD。

正如在LD中描述的那样,由于预测具有明确的物理含义,预测蒸馏可以向学生提供特定任务的知识。然而,与特征蒸馏方法相比,预测蒸馏的性能较差限制了其应用。

2.2 Cross-Head Knowledge Distillation

正如在第1节中所述,作者观察到直接蒸馏教师的预测会面临目标冲突问题,这妨碍了预测蒸馏方法取得良好的性能。为了缓解这个问题,在本节中作者提出了一种新颖的交叉Head知识蒸馏(CrossKD)方法。总体框架如图3所示。与许多以前的预测蒸馏方法类似,作者的CrossKD在预测上进行蒸馏过程。不同的是,CrossKD将学生的中间特征传递给教师的检测Head,并生成交叉Head预测以进行蒸馏。

对于像RetinaNet这样的检测器,每个检测Head通常由一系列卷积层组成,表示为。为了简化起见,作者假设每个检测Head中总共有个卷积层(例如,在具有4个隐藏层和1个预测层的RetinaNet中是5个)。作者用来表示生成的特征图,表示的输入特征图。预测是由最后一个卷积层生成的。因此,对于给定的教师-学生对,教师和学生的预测可以分别表示为和。

除了教师和学生的原始预测之外,CrossKD还额外将学生的中间特征传递给教师的检测Head的第个卷积层,从而产生交叉Head预测。给定,作者建议使用交叉Head预测与教师的原始预测之间的KD loss来作为作者CrossKD的目标,具体描述如下:

其中,和分别表示区域选择原则和归一化因子。为了避免设计复杂的,作者遵循训练密集检测器的默认操作。在分类分支中,是一个恒定函数,其值为1。在回归分支中,是指示符,在前景区域生成1,在背景区域生成0。根据每个分支的不同任务(例如分类或回归),作者使用不同类型的来有效地将特定于任务的知识传递给学生。

通过进行CrossKD,检测损失和蒸馏损失分别应用于不同的分支。如图3所示,检测损失的梯度通过学生的整个Head传递,而蒸馏损失的梯度通过冻结的教师层传播到学生的潜在特征上,从而在启发式上增加了教师和学生之间的一致性。与直接对教师-学生对之间的预测进行调整相比,CrossKD允许学生的一部分检测Head仅与检测损失相关,从而更好地优化GT目标。作者在实验部分进行了定量分析。

2.3 优化目标

训练的总损失可以表示为检测损失和蒸馏损失的加权和,如下所示:

其中,和表示检测损失,它们是在学生的预测值、与对应的真实目标值、之间计算得出的。额外的CrossKD损失表示为和,它们是在交叉Head的预测值、与教师的预测值、之间进行计算的。

作者在不同的分支中应用不同的距离函数来传递特定任务的信息。在分类分支中,作者将教师预测的分类分数视为软标签,并直接使用GFL中提出的Quality Focal Loss(QFL)来拉近教师和学生之间的距离。

至于回归,密集检测器中主要有两种形式的回归方法:

  • 第一种回归形式直接从Anchor框(例如RetinaNet,ATSS)或点(例如FCOS)回归边界框。在这种情况下,作者直接使用GIoU作为Dpred。
  • 另一种情况是回归形式预测一个向量来表示框位置的分布(例如GFL),它包含比边界框表示的Dirac分布更丰富的信息。

为了有效地蒸馏位置分布知识,作者采用KL散度(如LD)来进行位置知识的转移。关于损失函数的更多细节请参考补充材料。

3、实验

3.1 Positions to apply CrossKD

3.2 CrossKD v.s. Feature Imitation

3.3 CrossKD v.s. Prediction Mimicking

3.4 CrossKD v.s. HEAD

3.5 CrossKD for Lightweight Detectors

3.6 Comparison with SOTA KD Methods

3.7 CrossKD on Different Detectors

4、参考

[1].CrossKD: Cross-Head Knowledge Distillation for Dense Object Detection.

相关文章
|
9月前
|
机器学习/深度学习 人工智能 JSON
知识蒸馏方法探究:Google Distilling Step-by-Step 论文深度分析
大型语言模型(LLM)的发展迅速,从简单对话系统进化到能执行复杂任务的先进模型。然而,这些模型的规模和计算需求呈指数级增长,给学术界和工业界带来了挑战。为解决这一问题,知识蒸馏技术应运而生,旨在将大型模型的知识转移给更小、更易管理的学生模型。Google Research 提出的“Distilling Step-by-Step”方法不仅减小了模型规模,还通过提取推理过程使学生模型在某些任务上超越教师模型。该方法通过多任务学习框架,训练学生模型同时预测标签和生成推理过程,从而实现更高效、更智能的小型化模型。这为资源有限的研究者和开发者提供了新的解决方案,推动了AI技术的普及与应用。
402 19
知识蒸馏方法探究:Google Distilling Step-by-Step 论文深度分析
|
机器学习/深度学习 人工智能 自然语言处理
一文搞懂【知识蒸馏】【Knowledge Distillation】算法原理
一文搞懂【知识蒸馏】【Knowledge Distillation】算法原理
一文搞懂【知识蒸馏】【Knowledge Distillation】算法原理
|
机器学习/深度学习 人工智能 数据管理
文生图的基石CLIP模型的发展综述
CLIP(Contrastive Language-Image Pre-training)是OpenAI在2021年发布的多模态模型,用于学习文本-图像对的匹配。模型由文本和图像编码器组成,通过对比学习使匹配的输入对在向量空间中靠近,非匹配对远离。预训练后,CLIP被广泛应用于各种任务,如零样本分类和语义搜索。后续研究包括ALIGN、K-LITE、OpenCLIP、MetaCLIP和DFN,它们分别在数据规模、知识增强、性能缩放和数据过滤等方面进行了改进和扩展,促进了多模态AI的发展。
2224 0
|
自然语言处理 算法 数据挖掘
自蒸馏:一种简单高效的优化方式
背景知识蒸馏(knowledge distillation)指的是将预训练好的教师模型的知识通过蒸馏的方式迁移至学生模型,一般来说,教师模型会比学生模型网络容量更大,模型结构更复杂。对于学生而言,主要增益信息来自于更强的模型产出的带有更多可信信息的soft_label。例如下右图中,两个“2”对应的hard_label都是一样的,即0-9分类中,仅“2”类别对应概率为1.0,而soft_label
自蒸馏:一种简单高效的优化方式
|
机器学习/深度学习 存储 计算机视觉
【论文速递】TPAMI2022 - 自蒸馏:迈向高效紧凑的神经网络
【论文速递】TPAMI2022 - 自蒸馏:迈向高效紧凑的神经网络
|
数据可视化 计算机视觉
使用MMDetection进行目标检测
本文介绍了如何使用MMDetection进行目标检测。首先需按官方文档安装MMDetection,不熟悉的同学可参考提供的教程链接。安装完成后,只需准备模型配置文件、模型文件及待检测的图片或视频。示例代码展示了如何加载模型并进行图像检测,最后通过可视化展示检测结果,包括类别和置信度。
357 1
使用MMDetection进行目标检测
|
12月前
|
传感器 安全 Linux
linux为什么不是实时操作系统
标准Linux内核并不是实时操作系统,因为它在任务调度、中断处理和内核抢占方面无法提供严格的时间确定性。然而,通过使用PREEMPT_RT补丁、Xenomai等实时扩展,可以增强Linux的实时性能,使其适用于某些实时应用场景。在选择操作系统时,需要根据具体应用的实时性要求,综合考虑系统的性能和可靠性。
301 1
|
机器学习/深度学习 监控 算法
【论文速递】CVPR2021 - 通过解耦特征的目标检测知识蒸馏
【论文速递】CVPR2021 - 通过解耦特征的目标检测知识蒸馏
|
12月前
|
机器学习/深度学习 存储 算法
基于Actor-Critic(A2C)强化学习的四旋翼无人机飞行控制系统matlab仿真
基于Actor-Critic强化学习的四旋翼无人机飞行控制系统,通过构建策略网络和价值网络学习最优控制策略。MATLAB 2022a仿真结果显示,该方法在复杂环境中表现出色。核心代码包括加载训练好的模型、设置仿真参数、运行仿真并绘制结果图表。仿真操作步骤可参考配套视频。
385 0
|
监控 数据挖掘 数据安全/隐私保护
ERP系统中的固定资产管理
【7月更文挑战第25天】 ERP系统中的固定资产管理
545 2