如何用YOLOv5玩转半监督(附源码实现)

本文涉及的产品
文档翻译,文档翻译 1千页
文本翻译,文本翻译 100万字符
图片翻译,图片翻译 100张
简介: 如何用YOLOv5玩转半监督(附源码实现)

431c8584d2c5e10894f587832b0ce93b.png

Paper: https://arxiv.org/pdf/2211.02213.pdf


Github: https://github.com/hnuzhy/SSDA-YOLO


引言

Domain Adaptive Object Detection, DAOD, 即域自适应对象检测,旨在减轻由跨域问题而引起的模型泛化性能下降。现有的大部分 DAOD 方法都比较老旧,几乎都是以两阶段的 Faster R-CNN 算法实现的,其计算量非常大导致耗时严重。因此,今天为大家带来一篇全新的 Semisupervised Domain Adaptive YOLO, SSDA-YOLO, 即基于半监督域自适应的 YOLO 方法,通过将最火爆的单阶段目标检测器 YOLOv5 与域自适应相结合从而提高跨域检测的性能。


具体做法分三步:


将知识蒸馏框架与 Mean Teacher 模型结合起来,以帮助学生模型获得未标记目标域的实例级特征;

利用场景风格迁移在不同领域交叉生成伪图像,以弥补图像层次差异;

提出一种更加直观的一致性损失来进一步对齐跨域预测。

通过最终的实验表明,本文方法在包括 PascalVOC、Clipart1k、Cityscapes 和 Foggy Cityscapes 在内的四个公共基准上均取得了不错的效果。此外,为了验证其泛化性,作者还收集了真实场景下的检测数据进行了评估验证,结果表明 SSDA-YOLO 在这些 DAOD 任务中有了相当大的改进,充分揭示了所提出的自适应模块的有效性以及在 DAOD 中应用更先进检测器的必要性。


背景

目标检测

目标检测算法大致可以分为三种:


基于单阶段的目标检测器如 YOLO v2-v8、SSD、Retinanet等;

基于双阶段的目标检测器如 RCNN、Fast-RCNN、Faster-RCNN等;

基于 VIsion Transformer 的目标检测器如 Relation Net、DETR等;

当然,也可以大致分为 Anchor-Based 和 Anchor-Free,不过这并不是本文讨论的重点,笔者之前也总结过两篇关于目标检测系列的总结性文章,有兴趣的读者可以看看:

ad986140fba141674bd2e3f5d8a9b131.png

1193082cc8dce1e6654e2a322b1bc32c.png

域自适应

域自适应学习(Domain Adaptation Learning)能够有效地解决训练样本和测试样本概率分布不一致的学习问题,是当前机器学习的热点研究领域,在计算机视觉、自然语言处理,文本分析,生物信息学,跨语言分析,视频分析,情感分析和手写体识别等领域均有广泛应用。这块内容平常比较少讲,今天先简单的介绍下跨域目标检测和半监督域自适应两部分,后期有时间的话可以专门出一篇文章详细介绍 Domain Adaptation 这个方向,大家有兴趣的可以关注『CVHub』官方卫星号,敬请期待!


现有的跨域目标检测方法大都基于两阶段目标检测器 Faster R-CNN 实现的。


DA-Faster

edff322178d965eacc99823c8173cb4c.png


DA-Faster 属于开创性的工作,主要贡献是引入梯度反转层(GRL)并首先设计实例级和图像级对齐以提升在未知域的性能。


SWDA

f6909fb7d91e793dfea07f44031c977c.png


SWDA 则提出了相似的强局部和弱全局特征对齐以进一步改善了 DA-Faster 的性能。


SCL

d7518ff327b28d3ff9036c118ea6f659.png


SCL 同样也是基于 Faster R-CNN,同时提出了一种基于梯度分离的堆叠互补损失方法。

NLDA

538b8d5e9e14500e37fff4eda3d67dfd.png


NLDA 从鲁棒学习的角度解决了域适应问题,同时将其表述为带有噪声标签的训练。为此,作者提出了一个强大的对象检测框架,它对边界框类标签、位置和大小注释中的噪声均具有鲁棒性。最后,为了适应域转移的问题,使用一组噪声目标边界框在目标域上训练模型,这些边界框是由仅在源域中训练的检测模型获得的。


MEAA

e80122f07716cfca2d1ca2375de799c7.png


MEAA 提出了一种由局部不确定性注意力对齐模块 LUAA 和多级不确定性感知上下文对齐 MUCA 模块所组成的双阶段目标检测算法以更好的解决跨域目标检测的问题。


UMT

20dcf6c5bb939283fc85e43c1217e835.png


UMT 则提出了一种新的用于跨域对象检测的无偏均值教师模型。其揭示了在跨域场景中简单均值教师模型通常存在相当大的模型偏差,并通过几种简单但高效的策略消除了模型偏差。特别的,对于教师模型,作者提出了一种用于 MT 的跨域蒸馏方法,以最大限度地利用教师模型的专业知识。此外,对于学生模型,通过增加具有像素级自适应的训练样本来减轻其偏差。最后,对于教学过程,则是采用分布外估计策略来选择最适合当前模型的样本,以进一步增强跨域蒸馏过程。


MSDA

d3e1ffff48fe73fa4cc3b8a52ef88491.png


MSDA 着眼于来自多个源域的标记数据,提出了分而治之的主轴网络 DMSN 来增强域不变性并保持判别力。


USDAF

74d5e18a78280d885d0bf9b08313b319.png


USDAF 实现了具有多标签学习的通用尺度感知域自适应 Faster R-CNN,以减少训练期间的负迁移效应。


SIGMA

233ef649b838c30dc74357721b025788.png


SIGMA 提出了一种新颖的语义完全图匹配框架,通过将源数据和目标数据表示为图形,并将自适应重新表述为图形匹配问题。简言之便是利用图节点建立语义感知节点亲和力,并利用图边作为结构感知匹配损失中的二次约束,通过节点到节点图匹配实现细粒度的自适应。


无监督域适应(UDA)被定义为使模型从标记的源域适应未标记的目标域,起初被广泛研究用于图像分类任务。


DTPL

af65a8d6360ff22d2ca96855cfce1ee6.png


DTPL 在 UDA 基础上通过提供目标域图像的图像级注释提出了一个弱监督的渐进域适应框架。


MTOR

0c9f93648f7fd6df72a462597763daea.png


MTOR 首先学习了关系图,分别捕捉教师和学生的区域对之间的相似性。然后通过三个一致性正则化优化整个架构:


区域级一致性以对齐教师和学生之间的区域级预测;

图间一致性以匹配教师和学生之间的图结构;

内部图一致性,以增强学生图中同一类区域之间的相似性。

除此之外,最近新提出的 TRKP、TDD 和 PT 则基本是通过应用 MT 模型应用知识蒸馏框架来补救跨域差异和感知目标相关特征。此外,用于处理跨域语义分割任务的 DAFormer 在其自训练 pipeline 中也采用了 MT 模型。


方法

受上述方法的启发,本文提出了一种新颖的半监督域自适应 SSDA-YOLO,通过知识蒸馏结构中构建方法,并集成了流行的 MT 模型。如下图所示,它在源数据集中利用监督学习,并在目标数据集中执行无监督学习。此外,未标记的目标训练图像在输入教师模型之前用类源全局场景进行样式转换。


55ba7d37b9246c4d3bc628cde34d8350.png


上图为 SSDA-YOLO 的整体架构图,从左到右分别为:

image.png

而 SSDA-YOLO 模型本身由四部分组成:

具有知识蒸馏框架的 Mean Teacher 模型;

用于指导稳健的学生网络更新、用于减轻图像级域差异的伪交叉生成训练图像;

用于补救跨域差异的更新蒸馏损失;

新颖的一致性方法损失用于进一步纠正跨域目标偏见误差。

下面详细描述下这四个部分。


Mean Teacher Model

Mean Teacher,即 MT 模型最初是由《Weight-averaged consistency targets improve semi-supervised deep learning result》一文提出并用于图像分类中的半监督学习。它由一个典型的知识蒸馏结构和两个相同的模型架构(即学生网络和教师网络)组成。对于域自适应任务,学生模型常使用梯度下降优化器在源域中使用标记数据进行训练。根据 MT 模型设置,教师模型由学生模型的指数移动平均,即 EMA 进行权重更新。


具体来说,假设学生和教师模型的权重参数分别记为image.png 那么我们便可以在每个训练批次步骤更新 Pt,具体公式如下:

b176fcc374e9933c5eb6516dff7ca5dc.png

其中 γ  是指数衰减,其理论值接近 1.0,通常设置在 9 的倍数范围内,即 0.99 和 0.999 等。


当将 MT 模型应用于本文的跨域目标检测任务时,可以将未标记的目标域样本image.png 设置为教师模型的单一输入。此外,作者还在这些未标记的样本 image.png上部分训练学生模型。在蒸馏过程中,通过从教师模型预测中选择具有高概率的边界框作为伪标签,学生模型倾向于减少目标域上的方差并增强模型的鲁棒性。假设我们有来自相同图像image.png的教师模型的增强目标输入image.png   ,则可以使用如下定义的蒸馏损失来惩罚两个模型之间预测的不一致性:

0527b1d97f9df9bc073605389fb50d09.png

image.png

Pseudo Training Images Generation

通过第一步我们成功的构建了一个最基础的蒸馏网络,不过遗憾的是,此处学生模型的权重更新主要由源域中的图像主导。相比之下,教师模型则不会接触到源图像并由目标域特征进行引导。所以我们要如何缓解这种图像级别的域差异呢?毕竟这样一来会导致两个模型偏向于过拟合单一的伪标签了。

9419ab9aedcc3b913097d313271ee86b.png


本文受 SWDA 的方法启发(上面介绍了),同样基于 CycleGAN 在全局场景级别通过弱对齐来学习域不变特征。在本文中,作者选择生成类目标伪造源图像和类源伪造目标图像来进行训练。如上图所示,这里采用更高级的未配对图像转换器 CUT 以实现更快、更稳健的场景传输,是不是整体看起来有点诡异的诙谐感~哈哈哈。


Remedying Cross-Domain Discrepancy

image.png6ee5accf69afa3c46168dbc89293835e.png

image.png

4c7f20e8cd2815b18332454b6b8ea9d4.png

上述关系是通过 MT 模型中的 EMA 参数更新建立的。如此一来,学到的教师模型将不会显着倾向于只擅长预测目标域中的对象。此外,学生模型的训练将逐渐接近真实的目标领域,而由于伪标签本身的监督较弱,image.png的过滤预测虽然不是那么准确,但这些伪标签在促进细粒度实例级适应方面发挥着不可替代的作用。


Consistency Loss Function

尽管输入学生模型的源类和目标类配对图像image.png 具有不同的场景级数据分布,但它们本质上属于相同的标签空间。理想情况下,一个合理的假设是,输入两个域图像的学生模型的输出应该是一致的。因此,为了保证它们的输出尽可能接近,我们可以在相应的两个分支上添加一个新的约束。直观上,我们有三种选择方案:


在相应的特征图之间应用中间监督策略;

在最终预测之间应用误差约束;

同时结合以上两种策略。

中间监督策略是由卷积姿态机 CPM 提出的,起初是用于单人姿态估计任务。其实这种方法跟深监督机制的思想是类似的,所以我们如果拿来应用在监督训练中解决梯度消失问题也是挺合理的嘛,你说是不是这个道理老铁?但问题是此处我们不希望模型输出相似的中间特征,而是期望其预测输出尽可能一致。因此,这种方案我们还是 pass 掉吧。


作者最终是采用第二种方式,即通过计算两个最终输出之间的 L2 距离来进行约束,公式可以表述如下:

016f91fd1b06b122696ebfa896d4b2ce.png

当然,这里我们也可以使用 L1 损失来替代,至于哪个损失更好,详见消融实验部分。一致性损失理论上是可以用于纠正客观性和分类的跨域偏差,作者后续也通过实验来证明其有效性。笔者早期也对常用的损失函数进行了全面性的总结,大家有时间的也可以捧个场:

49f42fe2beb0e5eebbef7dab7740d892.png

此外,在推理阶段,我们只需要采用经过精细训练的学生模型,并将目标图像作为单一输入。我们的模型可以通过联合优化所有相关损失来以端到端的方式进行训练,最终整体的损失函数如下所示:

19c79839cf10bfde0882a34c211ba7f3.png

看起来有点复杂,但其实大家把它拆成几部分单独理解也是蛮简单的,建议对照下代码去看。


实验

54e391e6723231271cec881cf0baecc4.png

如上述表格所示,本文选取了 11 种具有代表性的方法进行比较,有意思的是这些方法全是基于 Faster R-CNN。从实验结果看出,本文提出的域自适应模块效能好像不是很哇塞,不过 YOLOv5 的推理效率各方面还是挺不错的,对落地比较友好,貌似 YOLOv8 也快发版了哦,目前模型权重也放出来了,大家也可以去尝试下。

e4a213be959018bc7530c47649b3fa5b.png


上面除了 [17,16,19] 是基于 FCOS 实现的,其他都是基于 Faster R-CNN。上述表格报告了 DAOD 在 Foggy Cityscapes 验证集上的所有结果。可以看出,由于 YOLOv5 的数据增强策略,Source Only 方法实现了与最近最先进的方法如 EPMDA 相当的 mAP 值达到了35.9。通过添加蒸馏损失和一致性损失,本文方法在 BaseDC 更是达到了 55.9 的 mAP,远高于迄今为止 TDD 中的最佳结果 49.2。

f744766421ab65aa0582d8fabb58b9b8.png

从定性结果来看,尽管本文方法在 Rel 场景下表现不比 PT 和 TDD 好,不过大体还是优于同年提出的方法如 TIA、MGADA 和 SIGMA,大概率是得益于所提出的自适应策略的有效性以及 YOLOv5 出色的性能。

3163e46020ea869c87d78c9c60864fe7.png

这张图更有意思,是作者自己采集的真实场景下的图片,可以看到,尽管与 Oracle 结果相比仍然存在差距,但本文所提出的方法可以明显缓解真实课堂中跨域行为检测的准确性下降。消融实验部分这里不讲啦,明天还要上班有点累了,感兴趣的小伙伴自己去看看吧,今天先讲到这里。


总结

本文提出了一种名为 SSDA-YOLO 的新型半监督跨域目标检测方法。同以往大部分基于二阶段的目标检测器 Faster-RCNN 方法不同,本文采用更实用的 YOLOv5 作为基础的检测器。具体来说,这个框架包含三个有效的组件。


首先,基于知识蒸馏结构,我们分别学习作为学生网络的 YOLOv5 和基于教师网络的 Mean Teacher 模型,以构建稳健的训练。其次,通过执行风格转移以交叉生成伪标签训练图像以减轻全局域差异。

最后,应用一致性损失函数来校正来自不同域但具有相同标签的图像的预测偏移。


通过对公共基准和自制的打哈欠行为数据集进行的广泛实验证明,SSDA-YOLO 在实际跨域目标检测应用中的有效性和优越性,同时也揭示了采用先进检测器推进 DAOD 这个领域的必要性。


写在最后

如果您也对人工智能和计算机视觉全栈领域感兴趣,强烈推荐您关注有料、有趣、有爱的公众号『CVHub』,每日为大家带来精品原创、多领域、有深度的前沿科技论文解读及工业成熟解决方案!欢迎添加小编微信号:cv_huber,一起探讨更多有趣的话题!

目录
相关文章
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】26.卷积神经网络之AlexNet模型介绍及其Pytorch实现【含完整代码】
【从零开始学习深度学习】26.卷积神经网络之AlexNet模型介绍及其Pytorch实现【含完整代码】
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】28.卷积神经网络之NiN模型介绍及其Pytorch实现【含完整代码】
【从零开始学习深度学习】28.卷积神经网络之NiN模型介绍及其Pytorch实现【含完整代码】
|
6月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】25.卷积神经网络之LeNet模型介绍及其Pytorch实现【含完整代码】
【从零开始学习深度学习】25.卷积神经网络之LeNet模型介绍及其Pytorch实现【含完整代码】
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】27.卷积神经网络之VGG11模型介绍及其Pytorch实现【含完整代码】
【从零开始学习深度学习】27.卷积神经网络之VGG11模型介绍及其Pytorch实现【含完整代码】
|
6月前
|
机器学习/深度学习 PyTorch 算法框架/工具
【从零开始学习深度学习】29.卷积神经网络之GoogLeNet模型介绍及用Pytorch实现GoogLeNet模型【含完整代码】
【从零开始学习深度学习】29.卷积神经网络之GoogLeNet模型介绍及用Pytorch实现GoogLeNet模型【含完整代码】
|
6月前
|
机器学习/深度学习 算法 PyTorch
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
【从零开始学习深度学习】45. Pytorch迁移学习微调方法实战:使用微调技术进行2分类图片热狗识别模型训练【含源码与数据集】
|
机器学习/深度学习 编解码 算法
经典神经网络论文超详细解读(四)——InceptionV2-V3学习笔记(翻译+精读+代码复现)
经典神经网络论文超详细解读(四)——InceptionV2-V3学习笔记(翻译+精读+代码复现)
209 0
经典神经网络论文超详细解读(四)——InceptionV2-V3学习笔记(翻译+精读+代码复现)
|
XML 存储 PyTorch
基于Pytorch的从零开始的目标检测 | 附源码
基于Pytorch的从零开始的目标检测 | 附源码
|
7月前
|
机器学习/深度学习 算法 关系型数据库
【PyTorch深度强化学习】DDPG算法的讲解及实战(超详细 附源码)
【PyTorch深度强化学习】DDPG算法的讲解及实战(超详细 附源码)
2110 1
|
7月前
|
机器学习/深度学习 数据采集 TensorFlow
【Tensorflow深度学习】实现手写字体识别、预测实战(附源码和数据集 超详细)
【Tensorflow深度学习】实现手写字体识别、预测实战(附源码和数据集 超详细)
209 1