目标识别知识蒸馏

简介: 翻译:《learning efficient object detection models with knowledge distillation》

learning efficient object detection models with knowledge distillation

尽管基于CNN的目标识别展现出了明显的准确率上的提升,但它也需要大量的时间来处理一个输入图像,尤其是在实时应用上。目前最先进的模型经常使用非常深的网络,包含大量的float类型的点操作。模型压缩学习等方法可以减少所需的参数量,但是也带来了准确率上的损耗。在这篇论文中,我们提出了一个新的紧凑且快速的目标识别网络,同时使用了知识蒸馏和hint learning提升了模型准确率。虽然知识蒸馏可以给简单的分类过程带来显著的提升,目标识别的复杂性不可避免地带来了新的挑战。我们用一些创新方法解决了这个问题,比如说用加权交叉熵损失来解决类别不平衡的问题,用有限的教师损失来使得回归成分和适应层从教师分布中更好地学习。我们在不同的数据集上用不同的蒸馏配置进行测试。我们的显示出在多类别识别模型上的准确度-速度权衡方面的一致的提升。

Introduction

最近几年,依赖cnn的目标识别模型在准确率上有了极大的提升。这使得视觉目标识别在从监控到自动驾驶领域都成为一个诱人的可能。然而,在很多应用中,速度都是主要需求,而这个需求和对准确率的要求是相矛盾的。因此,如果目标识别的提升依赖于更深的网络结构,那么它在计算时间上的要求也会更多。但是我们也都知道,深度神经网络为了帮助泛化,总是参数过量的。因此,为了实现更快的速度,一些工作开始探究新的类似于全卷积网络或者轻量模型和小滤波器等结构。但是即使已经获得了速度的进步,距离实时还是有点差距,还需要更仔细的设计和调优。

更深的网络在训练后往往有很好的表现,因为它们有足够好的网络能力。目标识别这类任务或许并不是很需要这种模型能力。一些图像分类工作使用了模型压缩,即分解每一层的权重,然后逐层重建或精调来恢复一些准确率。这个方法可以实现明显的速度上的提升,但是原始模型和压缩后的模型的准确率上仍然有一些差距,这个差距在对复杂的类似目标识别的任务使用模型压缩时会更大。在另一方面,在知识蒸馏方面的开创性工作表面,一个被训练来模型更深或者更复杂模型的浅层模型或者压缩模型能够恢复一些或者全部的准确率损失。然而,这些结果都只适用于使用了更简单的没有正则化的网络的分类问题。

在多类别目标识别任务上使用蒸馏技术,出于某些原因仍然很有挑战性。首先,识别模型的表现会因为模型压缩降级,因为识别任务的标签通常复杂且冗长。其次,知识蒸馏本身就是为每个类别一样重要的分类问题提出的,然而在识别问题中背景占了大多数。第三,识别任务更复杂,既包含了类别元素也包含了边界框回归。最后,一个额外的挑战是我们关注在相同领域内的知识转移, 不需要额外的数据和标签,不像其他需要来自别的域的数据的工作。

为了解决以上的挑战,我们提出了一个用知识蒸馏为目标识别任务训练快速模型的方法,我们的贡献包括四个方面:

  1. 我们为通过知识蒸馏学习多类别目标识别模型提出了一个端对端的可训练的框架。就我们所知,这是第一次成功使用知识蒸馏来解决多类别目标识别问题。
  2. 我们提出一个可以解决上述问题的新的损失函数。特别的,我们提出了一个加权交叉熵损失,解决了类别不平衡导致的对背景类别错误分类的影响。我们还为知识蒸馏提出了一个教师有限回归损失,为适应层提出了hint learning方法,这样学生可以更好地学习教师中间层的神经元的分布。
  3. 我们使用多种大规模的公开benchmark进行检验。我们的每个设计改进都展现了正面的影响,并且在所有的benchmarks上都有目标识别的准确率的明显提升。
  4. 我们将框架和泛化和欠拟合问题结合起来,提出了一些观点。

Method

在我们的研究中,我们使用Faster-RCNN作为我们的目标识别框架。Faster-RCNN由三个模块组成:1. 通过卷积层共享特征提取;2. 生成目标选取结果的一个RPN网络;3. 一个返回识别分数和空间调节向量的分类和回归网络。RCN和RPN都用1的输出作为特征,RCN也把RPN的结果作为输入。为了实现更为准确的目标识别结果,对三个组件都学习strong的模型是很重中之重。

Overall Structure

我们通过使用高性能的教师识别网络中的知识来学习一个strong且高效的学生目标识别器。我们的整体学习框架在图1中显示。首先,我们采用一个基于hint的学习方法来鼓励学生网络的特征表达。然后我们使用知识蒸馏框架在RPN和RCN学习一个更strong的分类模型。为了解决目标识别中类别不平衡的问题,我们在蒸馏框架中使用加权交叉熵损失。最后,我们将教师模型的回归输出以上限的形式转移,如果学生模型的回归输出比教师的更好,则不需要使用别的损失。

我们的整体学习目标可以写成以下的形式:

$$ L_{RCN}=\frac{1}{N}\sum_iL_{cls}^{RCN}+\lambda\frac{1}{N}\sum_jL_{reg}^{RCN} $$
$$ L_{RPN}=\frac{1}{M}\sum_iL_{cls}^{RPN}+\lambda\frac{1}{M}\sum_jL_{reg}^{RPN} $$
$$ L=L_{RPN}+L_{RCN}+\gamma L_{Hint} $$

$N$是RCN的批大小,$M$是RPN的批大小。$L_{cls}$定义了分类的损失函数,这个函数组合了硬softmax损失呵软知识蒸馏损失。$L_{reg}$是边界框回归损失,这个损失组合了平滑的L1损失和我们新提出的教师有限L2回归损失。最终,$L_{hint}$定义了基于hint的损失函数,鼓励学生模型去模仿老师的特征响应。$\lambda$和$\gamma$是控制不同损失间平衡的超参数,我们把它们分别固定在1和0.5。

Knowledge Distillation for Classification with Imbalanced Classes

为了训练分类网络,传统的知识蒸馏被提出,教师网络的预测被用来指导学生模型的训练。假设我们有数据集${x_i,y_i},i=1,2,...,n$,$x_i$是输入图像,$y_i$是类别标签。$t$是教师模型,$P_t=softmax(\frac{Z_t}{T})$是它的预测,$Z_t$是最终的输出分数。$T$是一个温度参数。相似的,我们为学生模型$s$定义$P_s=softmax(\frac{Z_s}{T})$。学生模型的训练要最优化以下损失函数:
$$ L_{cls}=\mu L_{hard}(P_s,y)+(1-\mu)L_{soft}(P_s,P_t) $$
$L_{hard}$是用Faster-RCNN使用的真实标签得到的损失,$L_{soft}$是使用教师的预测得到的软损失,$\mu$是用来平衡两个损失的参数。一个深度教师能够更好的拟合训练数据,并在测试时表现更好。软标签包含了教师模型发现的不同类别之间关系的信息。通过从软标签中学习,学生网络可以继承这样的隐藏信息。

硬损失和软损失都是交叉熵损失。不像简单的分类问题,识别问题需要处理严重的类别不平衡,这个问题主要是由背景部分造成的。在图像分类中,唯一可能的错误就是对前景的分类错误,而在目标识别中,没有区分前景和后景主导了误差部分,因为错误分类前景的概率很小。为了解决这个问题,我们采用了类别加权的交叉熵作为蒸馏损失。
$$ L_{soft}(P_s,P_t)=-\sum w_cP_tlogP_s $$
我们对背景使用更大的权重,对其他类别使用相对小的权重。比如我们对背景使用$w_0=1.5$,对其他的使用$w_i=1$。

挡$P_t$和硬标签很接近的时候,一个类别的概率会非常接近1,别的类别的概率会非常接近0。温度参数$T$是引入来平滑这个输出的。使用更高的温度将会促使产生更软的标签,所以接近0的概率值也不会呗损失函数忽略掉。这个对于越简单的任务越相关。但是对于较为困难的问题,在预测误差本来就比较大的情况下,一个较大的$T$反而会引进更多的不利于学习的noise。因此在更大的数据集中,通常会在分类时使用稍微小一点的$T$。对于更困难的类似目标识别的问题,我们发现在蒸馏损失中不使用温度参数时实际上效果最好。

Knowledge Distillation for Regression with Teacher Bounds

除了分类层,大部分基于CNN的 目标识别器也使用边界框回归来调整输入选取的位置和大小。学习一个好的回归模型对确保目标识别准确率至关重要。不像对离散的类别的蒸馏,教师模型的回归输出能够提供非常错误的指导,在真实值回归不是受限的时候。此外,教师模型可能会提供和真实方向相矛盾的回归方向。因此,不同于直接使用教师的回归输出作为目标,我们把它作为学生模型需要实现的一个上界。通常学生模型的回归向量需要和真实标签尽可能接近,但是如果学生模型的质量超过教师模型,我们就不给学生模型提供额外的loss啦。我们把这称为教师有限回归loss,$L_b$,被用来构成回归损失$L_{reg}$。
$$ L_b(R_s,R_t,y)=\left\{ \begin{array}{rcl} |R_s-y|^2_2,\text{if} |R_s-y|^2_2+m>|R_t-y|^2_2\\ 0,\text{otherwise}\\ \end{array} \right.\\ L_{reg}=L_{sL1}(R_s,y_{reg})+vL_b(R_s,R_t,y_{reg}) $$
$m$是一个边界,$y_{reg}$定义了回归真实标签,$R_s$是学生模型的回归输出。$R_t$是教师模型的预测,$v$是权重参数。$L_{sL1}$是平滑的L1损失。教师有限回归损失$L_b$只在学生模型的误差比教师模型的大时进行惩罚。虽然我们用了L2,但是别的L1或则平滑的L1也能用。我们的组合的损失鼓励学生在去接近甚至超过教师。但是在学生模型达到教师模型的表现后就不会继续push。

Hint Learning with Feature Adaptation

蒸馏在转移知识时,只需用最终输出。使用教师模型的中间表达作为hint有助于训练过程并且能够提升学生的最终表现。使用特征向量$V$和$Z$的L2距离。
$$ L_{Hint}(V,Z)=|V-Z|_2^2 $$
$Z$代表我们选来作为hint的中间层。$V$代表学生网络中被指导的层的输出。

使用hint learning时,学生和老师模型对应的层的神经元个数需要相同。为了匹配hint层和被指导层的通道数目,我们在被指导层后面加一个adaptation。这个适应层匹配神经元的scale使得学生模型的特征和老师模型的接近。当hint层和被引导层都是全连接层时,一个全连接层被用来作为适应层。当hint层和被指导层都是卷积层,我们可以使用1x1卷积来节省内存。我们发现使用适应层对实现高效率的知识转移是很重要的,即使两个层的通道数都一样。适应层也能够在两个层的特征格式不同时匹配这种不同。当hint层和被指导层都是卷积层时且分辨率不同时,我们用padding来匹配输出的数目。

相关文章
|
10天前
|
机器学习/深度学习 自然语言处理 数据可视化
深度探索变分自编码器在无监督特征学习中的应用
【4月更文挑战第20天】 在深度学习领域,无监督学习一直是研究的热点问题之一。本文聚焦于一种前沿的生成模型——变分自编码器(Variational Autoencoder, VAE),探讨其在无监督特征学习中的关键作用与应用潜力。不同于传统的摘要形式,本文将直接深入VAE的核心机制,分析其如何通过引入随机隐变量和重参数化技巧,实现对复杂数据分布的有效建模。文章还将展示VAE在多个实际数据集上的应用结果,验证其作为无监督特征提取工具的有效性和普适性。通过理论与实践的结合,本文旨在为读者提供关于VAE在无监督特征学习领域的全面认识。
|
9月前
|
机器学习/深度学习 计算机视觉
【图像分类】基于LIME的CNN 图像分类研究(Matlab代码实现)
【图像分类】基于LIME的CNN 图像分类研究(Matlab代码实现)
|
9月前
|
自然语言处理 算法 测试技术
PointGPT 论文解读,点云的自回归生成预训练
PointGPT 论文解读,点云的自回归生成预训练
370 0
|
10月前
|
机器学习/深度学习 存储 数据采集
使用深度神经网络对肿瘤图像进行分类
使用 Inception-v3 深度神经网络对可能不适合内存的多分辨率全玻片图像 (WSI) 进行分类。 用于肿瘤分类的深度学习方法依赖于数字病理学,其中整个组织切片被成像和数字化。生成的 WSI 具有高分辨率,大约为 200,000 x 100,000 像素。WSI 通常以多分辨率格式存储,以促进图像的高效显示、导航和处理。
91 0
|
10月前
|
传感器 机器学习/深度学习 数据采集
使用PointNet深度学习进行点云分类
训练 PointNet 网络以进行点云分类。 点云数据由各种传感器获取,例如激光雷达、雷达和深度摄像头。这些传感器捕获场景中物体的3D位置信息,这对于自动驾驶和增强现实中的许多应用非常有用。例如,区分车辆和行人对于规划自动驾驶汽车的路径至关重要。然而,由于每个对象的数据稀疏性、对象遮挡和传感器噪声,使用点云数据训练稳健分类器具有挑战性。深度学习技术已被证明可以通过直接从点云数据中学习强大的特征表示来解决其中的许多挑战。点云分类的开创性深度学习技术之一是PointNet。
578 0
|
机器学习/深度学习 编解码 移动开发
【论文解读】——基于多尺度卷积网络的遥感目标检测研究(姚群力,胡显,雷宏)
【论文解读】——基于多尺度卷积网络的遥感目标检测研究(姚群力,胡显,雷宏)
【论文解读】——基于多尺度卷积网络的遥感目标检测研究(姚群力,胡显,雷宏)
|
11月前
|
计算机视觉
【目标检测出】评价指标
【目标检测出】评价指标
101 0
|
机器学习/深度学习 算法
【文本分类】《融合后验概率校准训练的文本分类算法》
【文本分类】《融合后验概率校准训练的文本分类算法》
【文本分类】《融合后验概率校准训练的文本分类算法》
|
算法 开发工具 git
CenterNet+ deepsort实现多目标跟踪
CenterNet+ deepsort实现多目标跟踪
CenterNet+ deepsort实现多目标跟踪
|
机器学习/深度学习 存储 编解码
NeRF神经网络介绍
计算机视觉5D技术,NeRF神经网络是一个简单的完全连接的网络,它被训练为使用渲染损失来再现单个场景的输入视图。
1053 0
NeRF神经网络介绍