Focal Loss升级 | E-Focal Loss让Focal Loss动态化,类别极端不平衡也可以轻松解决(一)

简介: Focal Loss升级 | E-Focal Loss让Focal Loss动态化,类别极端不平衡也可以轻松解决(一)

1简介


长尾目标检测是一项具有挑战性的任务,近年来越来越受到关注。在长尾场景中,数据通常带有一个Zipfian分布(例如LVIS),其中有几个头类包含大量的实例,并主导了训练过程。相比之下,大量的尾类缺乏实例,因此表现不佳。长尾目标检测的常用解决方案是数据重采样、解耦训练和损失重加权。尽管在缓解长尾不平衡问题方面取得了成功,但几乎所有的长尾物体检测器都是基于R-CNN推广的两阶段方法开发的。在实践中,一阶段检测器比两阶段检测器更适合于现实场景,因为它们计算效率高且易于部署。然而,在这方面还没有相关的工作。

与包含区域建议网络(RPN)的两阶段方法,在将建议提供给最终的分类器之前过滤掉大多数背景样本相比,一阶段检测器直接检测规则的、密集的候选位置集上的目标。如图1所示,由于密集的预测模式,在一阶段检测器中引入了极端的前景-背景不平衡问题。结合长尾场景下的前景类别(即类别的前景样本)不平衡问题,严重损害了一阶段检测器的性能。

Focal Loss是解决前景-背景不平衡问题的一种常规解决方法。它侧重于硬前景样本的学习,并减少了简单背景样本的影响。这种损失再分配技术在类别平衡分布下效果很好,但不足以处理长尾情况下前景类别间的不平衡问题。为了解决这个问题,作者从两阶段中现有的解决方案(如EQLv2)开始,将它们调整在一阶段检测器中一起处理Focal Loss。作者发现这些解决方案与它们在两阶段检测器上的应用相比,只带来了微小的改进(见表1)。然后,作者认为,简单地结合现有的解决方案与Focal Loss,不能同时解决这两种类型的不平衡问题。通过比较不同数据分布中正样本与负样本的比值(见图2),进一步认识到这些不平衡问题的本质是类别之间的正负不平衡程度不一致。罕见类别比频繁类别遭受更严重的正负失衡,因此需要更多的重视。

在本文中,提出了均衡Focal Loss(EFL),通过将一个类别相关的调制因子引入Focal Loss。具有两个解耦的动态因子(即聚焦因子和加权因子)的调制因子独立处理不同类别的正负不平衡。focusing factor根据硬正样本对应类别的不平衡程度,决定了对硬正样本的学习集中度。加权因子增加了稀有类别的影响,确保了稀有样本的损失贡献不会被频繁的样本所淹没。这两个因素的协同作用使EFL在长尾场景中应用一阶段检测器时,能够均匀地克服前景-背景不平衡和前景类别不平衡。

在具有挑战性的LVISv1基准测试上进行了广泛的实验。通过简单有效的起始训练,达到29.2%的AP,超过了现有的长尾目标检测方法。在开放图像上的实验结果也证明了方法的泛化能力。

综上所述,主要贡献可以总结如下:

  1. 是第一个研究单阶段长尾目标检测的工作;
  2. 提出了一种新的均衡Focal Loss,它用一个类别相关的调制因子扩展了原始的Focal Loss。它是一种广义的Focal Loss形式,可以同时解决前景-背景不平衡和前景类别不平衡的问题;
  3. 在LVISv1基准测试上进行了广泛的实验,结果证明了方法的有效性。它建立了一种新的先进技术,可以很好地应用于任何单阶段检测器。

2相关工作


2.1 普通目标检测

近年来,由于卷积神经网络的巨大成功,计算机视觉社区在目标检测方面有了显著的进步。现代目标检测框架大致可分为两阶段方法和一阶段方法。

1、两阶段目标检测

随着快速RCNN的出现,两阶段方法在现代目标检测中占据了主导地位。两阶段检测器首先通过区域建议机制(如选择性搜索或RPN)生成目标建议,然后根据这些建议执行特征图的空间提取,以便进一步预测。由于这个建议机制,大量的背景样本被过滤掉了。在之后,大多数两阶段检测器的分类器在前景和背景样本的相对平衡的分布上进行训练,比例为1:3。

2、一阶段目标检测

一般来说,一阶段目标检测器有一个简单和快速的训练管道,更接近真实世界的应用。在单阶段的场景中,检测器直接从特征图中预测检测框。一阶段目标检测器的分类器在一个有大约104到105个候选样本的密集集合上进行训练,但只有少数候选样本是前景样本。有研究试图从困难样本挖掘的视角或更复杂的重采样/重称重方案来解决极端的前景-背景不平衡问题。

Focal Loss及其衍生重塑了标准的交叉熵损失,从而减轻了分配给良好分类样本的损失,并集中对硬样本进行训练。受益于Focal Loss,一阶段目标检测器实现了非常接近两阶段目标检测器方法的性能,同时具有更高的推理速度。最近,也有学者尝试从标签分配的角度来提高性能。

而本文提出的EFL可以很好地应用于这些单阶段框架,并在长尾场景中带来显著的性能提高。

2.2  Long-Tailed 目标检测

与一般的目标检测相比,长尾目标检测是一项更加复杂的任务,因为它存在着前景类别之间的极端不平衡。解决这种不平衡的一个直接方法是在训练期间执行数据重采样。重复因子采样(RFS)对来自尾部类的训练数据进行过采样,而对来自图像级的头部类的训练数据进行过采样。

有学者以解耦的方式训练检测器,并从实例级别提出了一个具有类平衡采样器的额外分类分支。Forest R-CNN重新采样从RPN与不同的NMS阈值的建议。其他工作是通过元学习方式或记忆增强方式实现数据重采样。

损失重加权是解决长尾分布问题的另一种广泛应用的解决方案。有研究者提出了均衡损失(EQL),它减轻了头类对尾类的梯度抑制。EQLv2是EQL的一个升级版,它采用了一种新的梯度引导机制来重新衡量每个类别的损失贡献。

也有学者从无统计的角度解释了长尾分布问题,并提出了自适应类抑制损失(ACSL)。DisAlign提出了一种广义的重加权方法,在损失设计之前引入了一个平衡类。除了数据重采样和损失重加权外,许多优秀的工作还从不同的角度进行了尝试,如解耦训练、边缘修改、增量学习和因果推理。

然而,所有这些方法都是用两阶段目标检测器开发的,到目前为止还没有关于单阶段长尾目标检测的相关工作。本文提出了基于单阶段的长尾目标检测的第一个解决方案。它简单而有效地超越了所有现有的方法。


3本文方法


3.1 再看Focal Loss

在一阶段目标检测器中,Focal Loss是前景-背景不平衡问题的解决方案。它重新分配了易样本和难样本的损失贡献,大大削弱了大多数背景样本的影响。二分类Focal Loss公式为:

image.png

表示一个候选目标的预测置信度得分,而术语是平衡正样本和负样本的重要性的参数。调节因子是Focal Loss的关键组成部分。通过预测分数和Focal参数,降低了简单样本的损失,侧重于困难样本的学习。

大量的阴性样本易于分类,而阳性样本通常很难分类。因此,阳性样本与阴性样本之间的不平衡可以大致看作是容易样本与困难样本之间的不平衡。Focal参数决定了Focal Loss的影响。它可以从等式中得出结论:一个大的将大大减少大多数阴性样本的损失贡献,从而提高阳性样本的影响。这一结论表明,阳性样本与阴性样本之间的不平衡程度越高,的期望值越大。

当涉及到多类情况时,Focal Loss被应用于C分类器,这些分类器作用于每个实例的s型函数转换的输出日志。C是类别的数量,这意味着一个分类器负责一个特定的类别,即一个二元分类任务。由于Focal Loss同样对待具有相同调制因子的所有类别的学习,因此它未能处理长尾不平衡问题(见表2)。

3.2 Equalized Focal Loss

在长尾数据集(即LVIS)中,除了前景-背景不平衡外,一阶段检测器的分类器还存在前景类别之间的不平衡。

image.png

如图2所示,如果从y轴上看,正样本与负样本的比值远小于零,这主要揭示了前景和背景样本之间的不平衡。这里将该比值的值称为正负不平衡度。从x轴的角度可以看出,不同类别之间的不平衡程度存在很大差异,说明前景类别之间的不平衡。

显然,在数据分布(即COCO)中,所有类别的不平衡程度是相似的。因此,Focal Loss使用相同的调制因子就足够了。相反,这些不平衡的程度在长尾数据的情况下是不同的。罕见类别比常见类别遭受更严重的正负失衡。如表1中所示。大多数一阶段检测器在罕见类别上的表现比在频繁类别上更差。这说明,同一调制因子并不适用于所有不同程度的不平衡问题。

1、Focusing Factor

在此基础上,提出了均衡Focal Loss(EFL),该方法采用类别相关Focusing Factor来解决不同类别的正负不平衡。将第类的损失表述为:

image.png

其中和与在Focal Loss中的相同。

参数是第类的Focusing Factor,它在Focal Loss中起着与类似的作用。正如在前面提到的不同的值对应于不同程度的正向-负向不平衡问题。这里采用一个大的来缓解严重的正负失衡问题。对于有轻微不平衡的频繁类别,一个小的是合适的。Focusing Factor被解耦为2个组件,特别是一个与类别无关的参数和一个与类别相关的参数:

image.png

其中,表示控制分类器基本行为的平衡数据场景中的Focusing Factor。参数≥0是一个与第类不平衡度相关的变量参数。它决定了学习的注意力集中在正负不平衡问题上。受EQLv2的启发,采用了梯度引导的机制来选择。参数表示第类正样本与负样本的累积梯度比。

值较大表示第j类(例如频繁)是训练平衡,较小值表示类别(例如罕见)是训练不平衡。为了满足对的要求,将的值定义在[0,1]范围内,并采用来反转其分布。超参数s是决定EFL中上限的缩放因子。与Focal Loss相比,EFL可以独立处理每个类别的正负不平衡问题,从而带来性能的提升(见表3)。

2、Weighting Factor

即使使用了Focusing Factor ,仍然有2个障碍损失的性能:

  1. 对于二元分类任务,更大的适用于更严重的正负不平衡问题。而在多类的情况下,如图3a所示,对于相同的,的值越大,损失就越小。这导致了这样一个事实:当想要增加对学习一个具有严重的正负不平衡的类别的注意力时,必须牺牲它在整个训练过程中所做的部分损失贡献。这种困境阻碍了稀有类别获得优异的表现。
  2. 当较小时,来自不同Focusing Factor的不同类别样本的损失将收敛到一个相似的值。实际上,期望罕见的困难样本比频繁的困难样本做出更多的损失贡献,因为它们是稀缺的,并且不能主导训练过程。

于是作者提出了Weighting Factor,通过重新平衡不同类别的损失贡献来缓解上述问题。与Focusing Factor相似为罕见类别分配了一个较大的权重因子值,以提高其损失贡献,同时保持频繁类别的权重因子接近于1。具体地说,将第类的Weighting Factor设置为,以与Focusing Factor相一致。EFL的最终公式为:

image.png

如图3b所示,使用Weighting Factor,EFL显著增加了稀有类别的损失贡献。同时,与频繁的困难样本相比,它更侧重于罕见的困难样本的学习。

Focusing Factor和Weighting Factor构成了EFL的与类别相关的调节因子。它使分类器能够根据样本的训练状态和对应的类别状态动态调整样本的损失贡献。Focusing Factor和Weighting Factor在EFL中均有重要作用。同时,在平衡数据分布中,所有的EFL都相当于Focal Loss。这种吸引人的特性使得EFL可以很好地应用于不同的数据分布和数据采样器之中。

PyTorch实现如下:

@LOSSES_REGISTRY.register('equalized_focal_loss')
class EqualizedFocalLoss(GeneralizedCrossEntropyLoss):
    def __init__(self,
                 name='equalized_focal_loss',
                 reduction='mean',
                 loss_weight=1.0,
                 ignore_index=-1,
                 num_classes=1204,
                 focal_gamma=2.0,
                 focal_alpha=0.25,
                 scale_factor=8.0,
                 fpn_levels=5):
        activation_type = 'sigmoid'
        GeneralizedCrossEntropyLoss.__init__(self,
                                             name=name,
                                             reduction=reduction,
                                             loss_weight=loss_weight,
                                             activation_type=activation_type,
                                             ignore_index=ignore_index)
        # Focal Loss的超参数
        self.focal_gamma = focal_gamma
        self.focal_alpha = focal_alpha
        # ignore bg class and ignore idx
        self.num_classes = num_classes - 1
        # EFL损失函数的超参数
        self.scale_factor = scale_factor
        # 初始化正负样本的梯度变量
        self.register_buffer('pos_grad', torch.zeros(self.num_classes))
        self.register_buffer('neg_grad', torch.zeros(self.num_classes))
        # 初始化正负样本变量
        self.register_buffer('pos_neg', torch.ones(self.num_classes))
        # grad collect
        self.grad_buffer = []
        self.fpn_levels = fpn_levels
        logger.info("build EqualizedFocalLoss, focal_alpha: {focal_alpha}, focal_gamma: {focal_gamma},scale_factor: {scale_factor}")
    def forward(self, input, target, reduction, normalizer=None):
        self.n_c = input.shape[-1]
        self.input = input.reshape(-1, self.n_c)
        self.target = target.reshape(-1)
        self.n_i, _ = self.input.size()
        def expand_label(pred, gt_classes):
            target = pred.new_zeros(self.n_i, self.n_c + 1)
            target[torch.arange(self.n_i), gt_classes] = 1
            return target[:, 1:]
        expand_target = expand_label(self.input, self.target)
        sample_mask = (self.target != self.ignore_index)
        inputs = self.input[sample_mask]
        targets = expand_target[sample_mask]
        self.cache_mask = sample_mask
        self.cache_target = expand_target
        pred = torch.sigmoid(inputs)
        pred_t = pred * targets + (1 - pred) * (1 - targets)
  # map_val为:1-g^j
        map_val = 1 - self.pos_neg.detach()
        # dy_gamma为:gamma^j
        dy_gamma = self.focal_gamma + self.scale_factor * map_val
        # focusing factor
        ff = dy_gamma.view(1, -1).expand(self.n_i, self.n_c)[sample_mask]
        # weighting factor
        wf = ff / self.focal_gamma
        # ce_loss
        ce_loss = -torch.log(pred_t)
        cls_loss = ce_loss * torch.pow((1 - pred_t), ff.detach()) * wf.detach()
        if self.focal_alpha >= 0:
            alpha_t = self.focal_alpha * targets + (1 - self.focal_alpha) * (1 - targets)
            cls_loss = alpha_t * cls_loss
        if normalizer is None:
            normalizer = 1.0
        return _reduce(cls_loss, reduction, normalizer=normalizer)
 # 收集梯度,用于梯度引导的机制
    def collect_grad(self, grad_in):
        bs = grad_in.shape[0]
        self.grad_buffer.append(grad_in.detach().permute(0, 2, 3, 1).reshape(bs, -1, self.num_classes))
        if len(self.grad_buffer) == self.fpn_levels:
            target = self.cache_target[self.cache_mask]
            grad = torch.cat(self.grad_buffer[::-1], dim=1).reshape(-1, self.num_classes)
            grad = torch.abs(grad)[self.cache_mask]
            pos_grad = torch.sum(grad * target, dim=0)
            neg_grad = torch.sum(grad * (1 - target), dim=0)
            allreduce(pos_grad)
            allreduce(neg_grad)
   # 正样本的梯度
            self.pos_grad += pos_grad
            # 负样本的梯度
            self.neg_grad += neg_grad
            # self.pos_neg=g_j:表示第j类正样本与负样本的累积梯度比
            self.pos_neg = torch.clamp(self.pos_grad / (self.neg_grad + 1e-10), min=0, max=1)
            self.grad_buffer = []


相关文章
|
7月前
|
计算机视觉
如何理解focal loss/GIOU(yolo改进损失函数)
如何理解focal loss/GIOU(yolo改进损失函数)
|
7月前
|
机器学习/深度学习 监控 数据可视化
训练损失图(Training Loss Plot)
训练损失图(Training Loss Plot)是一种在机器学习和深度学习过程中用来监控模型训练进度的可视化工具。损失函数是衡量模型预测结果与实际结果之间差距的指标,训练损失图展示了模型在训练过程中,损失值随着训练迭代次数的变化情况。通过观察损失值的变化,我们可以评估模型的拟合效果,调整超参数,以及确定合适的训练停止条件。
1217 5
|
2月前
|
机器学习/深度学习 算法 PyTorch
深度学习笔记(十三):IOU、GIOU、DIOU、CIOU、EIOU、Focal EIOU、alpha IOU、SIOU、WIOU损失函数分析及Pytorch实现
这篇文章详细介绍了多种用于目标检测任务中的边界框回归损失函数,包括IOU、GIOU、DIOU、CIOU、EIOU、Focal EIOU、alpha IOU、SIOU和WIOU,并提供了它们的Pytorch实现代码。
206 1
深度学习笔记(十三):IOU、GIOU、DIOU、CIOU、EIOU、Focal EIOU、alpha IOU、SIOU、WIOU损失函数分析及Pytorch实现
|
4月前
|
机器学习/深度学习 算法 Serverless
三元组损失Triplet loss 详解
在这篇文章中,我们将以简单的技术术语解析三元组损失及其变体批量三元组损失,并提供一个相关的例子来帮助你理解这些概念。
74 2
|
7月前
|
机器学习/深度学习
损失函数大全Cross Entropy Loss/Weighted Loss/Focal Loss/Dice Soft Loss/Soft IoU Loss
损失函数大全Cross Entropy Loss/Weighted Loss/Focal Loss/Dice Soft Loss/Soft IoU Loss
148 2
Focal Loss升级 | E-Focal Loss让Focal Loss动态化,类别极端不平衡也可以轻松解决(二)
Focal Loss升级 | E-Focal Loss让Focal Loss动态化,类别极端不平衡也可以轻松解决(二)
204 0
|
机器学习/深度学习 PyTorch 算法框架/工具
深入理解二分类和多分类CrossEntropy Loss和Focal Loss
多分类交叉熵就是对二分类交叉熵的扩展,在计算公式中和二分类稍微有些许区别,但是还是比较容易理解
1439 0
|
数据可视化 计算机视觉 异构计算
VariFocalNet | IoU-aware同V-Focal Loss全面提升密集目标检测(附YOLOV5测试代码)(二)
VariFocalNet | IoU-aware同V-Focal Loss全面提升密集目标检测(附YOLOV5测试代码)(二)
381 1
|
算法 数据挖掘 计算机视觉
目标检测中 Anchor 与 Loss 计算
目标检测中 Anchor 与 Loss 计算
188 0
【学习】loss图和accuracy
【学习】loss图和accuracy
393 0