DETR | 基于匈牙利算法的样本分配策略

简介: DETR | 基于匈牙利算法的样本分配策略

前不久,沐神对DETR进行了讲解,其实之前也对DETR进行了介绍,见:论文阅读笔记 | 目标检测算法——DETR。现对DETR的核心内容进行重温,也就是其所提出的目标检测的end-to-end框架,输入的是一张图像,输出的直接是最后的预测标注结果,不再需要后处理(nms非极大值抑制)。


为什么希望end-to-end,是因为nms对于模型的调参比较复杂。而且因为不是所有的硬件都支持这nms,所有即使训练好了模型,部署起来也比较困难。DETR既不需要proposal,也不需要anchor,直接利用transformer的全局建模能力,将目标检测看成是一个集合预测的问题。所有,其实从另一个方面来说,DETR可以看成是一个框架,一种用二分图匹配算法进行一个集合预测的框架方法。


1. DERT回顾


为了实现end-to-end的目标,DETR需要对预测对象和ground truth之间进行二分图匹配,从而计算损失。具体来说,在paper中设定了一个固定的object query=100,也就是每一张图像都输出100个最后的预测结果。假设这张图像只有2个ground truth,在训练过程中就是对这输出的100个object query和这两个ground truth框计算一个match loss,从而决定出这100个object query中哪两个框是独一无二的一一的分别对应到这两个ground truth上。一旦决定好了这种一一对应的匹配关系之后,才会像普通的目标检测一样去算一个分类的loss,再算一个bounding box回归的loss。而那些剩下的没有匹配到ground truth的98个框,就被被标记为∅,没有物体。


DERT处理流程图:

image.png


总结DETR训练的大致过程,分为以下四个步骤:


  1. 用卷积神经网络抽特征
  2. 用Transformer Encoder去学全局特征,帮助后面做检测
  3. 用Transformer Decoder调整object query生成100个预测框
  4. 用输出的100个object query框和ground truth框做一个匹配,然后在一一配对好的框中去计算目标检测的loss(分类loss与回归loss)


对于推理过程,其和训练过程的前3步都是一致的,因为训练的时候需要算这个二分图匹配的loss,推理时候这个loss是不需要的,直接在最后的输出上用一个阈值卡一下模型的置信度。然后最后模型置信度高的(置信度大于0.7)预测就会被保留下来。


所以,整个过程都可以看出,这里完全没有anchor生成与nms后处理的流程。


2. DETR中的二分图匹配


在上述我们讲到整个DETR的正负样本分配的重点就是使用了二分图匹配,来对预测框与gorund truth进行了一个最佳匹配。那么解决这两个集合之间的一个最佳匹配的方式,使得cost代价最少的问题,就被称为二分图匹配问题。


举一个二分图匹配的例子,如何分配一些工人干一些活,从而能让最后的支出最小。因为每个工人有各自的长处与短处,所以他们干不同活所需要的回报不同,那么最后每个工人对应每个任务就形成了一个n工人与n任务的一个nxn矩阵,这个矩阵称为cost matric。最优二分图匹配的意思就是最后能够找到一个唯一解,能够给每个人都能分配其最擅长的一份工作,使得把这三个工作完成最后的价钱最低。匈牙利算法就是解决二分图匹配问题中一个比较有名的算法。

image.png


那么,在目标检测中,这里的工人就可以看成是集合预测出来的预测框。因为这里有100个,所以就是有100个预测框,而这里的不同任务就对应着ground truth(同样设定为100,如果没有100则用∅填充),最后工人的价钱就是每个预测框与每个ground truth的匹配损失(一般是分类算是与回归损失),这里是损失计算为是:

image.png

把每个pred与ground truth计算器匹配,把cost matric填充完整,cost matric改变成如下所示:

image.png


现在对于这种二分图匹配的问题一般都要成熟的解决方案,比如scipy库中的 linear_sum_assigment 函数。这个函数的输入就是cost matrix,只需要把这个cost matrix传给他,就可以返回每一行(列)应该选择的索引。而且cost matrix也可也是长方形,也不一定是正方形,都是能算出一个最优匹配的。


也就是说,把每个pred与ground truth计算一一匹配的损失,把cost matric填充完整,就可以丢到scipy库中的 linear_sum_assigment 函数中,得到最后的最优解。这里的匹配方式约束更强,一定要得到这个一对一的匹配关系,也就是只有一个框与ground truth的一个框是对应的,这样后面才不需要去做那个后处理nms。


那么一旦知道了这100个框中有哪几个框是跟这个Ground Truth框是对应的,那么接下来就可以算一个真正的目标函数,然后用这个loss去进行梯度传播来更新模型的参数。不够这里是回归损失与分类损失是有改变的,对于分类损失这里去除了log,而对于回归损失是l1 loss与generalized iou loss联合来算bounding box loss。


这篇文章主要是介绍样本匹配,所以详细的损失函数与DETR结构介绍见:论文阅读笔记 | 目标检测算法——DETR


DETR的结构补充介绍

这里以一个 3x800x1066 大小的图像来进行一个简单的前向传播,如下所示:

image.png


在论文中使用了6个Transformer Encoder与6个Transformer Decoder,每层的输入与输出维度是不变的。所以object query这个可学习的embedding,在不断的与Encoder的全局编码信息作注意力交互时,始终保持其维度信息 100x256 的不变。接着就是丢进去一个权重共享的 feet forward network 中,结果与ground truth做匈牙利算法匹配,再利用与ground truth一一匹配好的唯一预测来计算损失,然后进行梯度回传更新参数。


细节1:在每一个Decoder中,都会先做一次这个object query的自注意力操作,但是第一层可以不用做,后面的几层全部都需要做。object query的自注意力操作目的是为了移出这种冗余框,因为其之间相互通信之后就知道每个query可能得到怎么样的一个框,然后进行不要去做重复的框。


细节2:在最后做loss的时候,为了让模型收敛得更快以及训练得更稳定,作者在Decoder后面添加了许多auxiliary loss(辅助损失)。意思就是不光在最后一层的Decoder输出中计算loss,在前面的每一层Decoder中也可以计算loss,因为每个decoder的输出都是相同的维度的,都是 100x256 。也就是都可以把这个结果丢进去给FFN然后得到输出。所以作者这里在6层的Decoder后面,都加了FFN,而去得到了这个目标检测的输出去计算loss。当然,这些FFN都是共享参数的。


3. DETR中的匈牙利算法实现


项目地址:https://github.com/facebookresearch/detr


参考代码:


from scipy.optimize import linear_sum_assignment
...
# 匈牙利算法实现二分图匹配
class HungarianMatcher(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network
    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """
    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
        """Creates the matcher
        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox
        self.cost_giou = cost_giou
        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"
    @torch.no_grad()
    def forward(self, outputs, targets):
        """ Performs the matching
        Params:
            outputs: This is a dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
                           objects in the target) containing the class labels
                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]
        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
        # Also concat the target labels and boxes
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_bbox = torch.cat([v["boxes"] for v in targets])
        # Compute the classification cost. Contrary to the loss, we don't use the NLL,
        # but approximate it in 1 - proba[target class].
        # The 1 is a constant that doesn't change the matching, it can be ommitted.
        cost_class = -out_prob[:, tgt_ids]
        # Compute the L1 cost between boxes
        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
        # Compute the giou cost betwen boxes
        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
        # Final cost matrix
        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
        C = C.view(bs, num_queries, -1).cpu()
        sizes = [len(v["boxes"]) for v in targets]
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
# 损失计算
class SetCriterion(nn.Module):
    """ This class computes the loss for DETR.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
    """
    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer('empty_weight', empty_weight)
    def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
        """
        assert 'pred_logits' in outputs
        src_logits = outputs['pred_logits']
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o
        loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
        losses = {'loss_ce': loss_ce}
        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
        return losses
    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_boxes):
        """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses
    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
        """
        assert 'pred_boxes' in outputs
        idx = self._get_src_permutation_idx(indices)
        src_boxes = outputs['pred_boxes'][idx]
        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
        losses = {}
        losses['loss_bbox'] = loss_bbox.sum() / num_boxes
        loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
            box_ops.box_cxcywh_to_xyxy(src_boxes),
            box_ops.box_cxcywh_to_xyxy(target_boxes)))
        losses['loss_giou'] = loss_giou.sum() / num_boxes
        return losses
    def loss_masks(self, outputs, targets, indices, num_boxes):
        """Compute the losses related to the masks: the focal loss and the dice loss.
           targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
        """
        assert "pred_masks" in outputs
        src_idx = self._get_src_permutation_idx(indices)
        tgt_idx = self._get_tgt_permutation_idx(indices)
        src_masks = outputs["pred_masks"]
        src_masks = src_masks[src_idx]
        masks = [t["masks"] for t in targets]
        # TODO use valid to mask invalid areas due to padding in loss
        target_masks, valid = nested_tensor_from_tensor_list(masks).decompose()
        target_masks = target_masks.to(src_masks)
        target_masks = target_masks[tgt_idx]
        # upsample predictions to the target size
        src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
                                mode="bilinear", align_corners=False)
        src_masks = src_masks[:, 0].flatten(1)
        target_masks = target_masks.flatten(1)
        target_masks = target_masks.view(src_masks.shape)
        losses = {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
            "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
        }
        return losses
    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx
    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx
    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'boxes': self.loss_boxes,
            'masks': self.loss_masks
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)
        # Compute the average number of target boxes accross all nodes, for normalization purposes
        num_boxes = sum(len(t["labels"]) for t in targets)
        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
        if is_dist_avail_and_initialized():
            torch.distributed.all_reduce(num_boxes)
        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    if loss == 'masks':
                        # Intermediate masks losses are too costly to compute, we ignore them.
                        continue
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs = {'log': False}
                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)
        return losses


4. 匈牙利算法解析


详细见参考资料2.


这里再用一个例子,来看匈牙利算法是如何进行的。


现在有五个工作(搬砖、送快递、洗衣服、打扫、做饭)需要安排给有5个宝可梦(皮卡丘、杰尼龟、喷火龙、小拳石、妙蛙草)。每个宝可梦对每一项工作收费标准不同。如何安排工作使得成本最低。


(注:①每个宝可梦只能做一项工作;②每项工作只能分配给一个宝可梦做;③所有工作都要安排完。)

image.png


根据这个任务就可以抽象为一个数学问题,可以表述为以下的数学公式:

image.png

前置知识:指派问题的最优解有这样一个性质,若从系数矩阵的一行(列)各元素中分别减去该行(列)的最小元素,得到新矩阵,那么以新矩阵为系数矩阵求得的最优解和用原矩阵求得的最优解相同.利用这个性质,可使原系数矩阵变换为含有很多0元素的新矩阵,而最优解保持不变。


匈牙利算法处理流程如下:


  1. 归约

step1:行归约(使得每行至少有一个零)

image.png


step2:列归约(使得每列至少有一个零)

此处由于每列恰好最小值已为零,故列规约后结果不变。

image.png


  1. 试指派(找到归约后的成本矩阵中独立的零)

step1:找到含0数目最少的行或列(不妨取行) 随后将该行第一个零置为“T0”,随后将“T0”所在行和列中其他的零置“F0”。依次类推,完成归约矩阵所有行的操作。


以上述的归约矩阵为例:

image.png


step2:用最少的直线来覆盖矩阵中所有的零。

具体方法:


① 对没有T0的行用★进行标


② 对★所标记的行中存在的F0所在的列索引进行标记(同样标记★)


③ 对★所标记的列中,对T0所在的行索引进行标记(同样标记★)


④ 重复2、3步骤,直至找不到可以标记的行和列


①~④步骤以上述试指派为例

image.png


⑤ 对没有标记的行画横线表示去掉这一行,对标记的列画横线表示去掉这一列,这样就得到能覆盖所有0 的最小横线。

image.png

step3:变换试指派矩阵,增加其中的0元素

具体方法:


① 在未被直线覆盖的所有元素中找到min


② 在被★标记的所有行中减去这个元素;


③ 在被★标记的所有列中加上这个元素(保证原来的零不变);


④ 得到新的归约矩阵。返回step1。直至满足约束条件

image.png


针对以上例子,进行代码测试:


from scipy.optimize import linear_sum_assignment
import numpy as np
cost_mat = np.array([[12, 7, 9, 7, 9],
                     [8, 9, 6, 6, 6],
                     [7, 17, 12, 14, 9],
                     [15, 14, 6, 6, 10],
                     [4, 10, 7, 10, 9]])
index, target = linear_sum_assignment(cost_mat)
# idnex: (array([0, 1, 2, 3, 4])
# target: array([1, 2, 4, 3, 0]))
cost_ass = cost_mat.copy()
cost_ass = cost_mat.copy()
for i, t in zip(index, target):
    cost_ass[i][t] = 0
print(cost_ass)
# 输出: 赋予0值的就是被宝可梦被赋予的工作
# array([[12,  0,  9,  7,  9],
#        [ 8,  9,  0,  6,  6],
#        [ 7, 17, 12, 14,  0],
#        [15, 14,  6,  0, 10],
#        [ 0, 10,  7, 10,  9]])
# 代价
cost_mat[index, target].sum()
# 输出: 32


参考资料:


1. DETR 论文精读【论文精读】


2. 匈牙利算法与python实现


目录
相关文章
|
6月前
|
人工智能 自然语言处理 算法
当prompt策略遇上分治算法,南加大、微软让大模型炼成“火眼金睛”
【2月更文挑战第24天】当prompt策略遇上分治算法,南加大、微软让大模型炼成“火眼金睛”
60 2
当prompt策略遇上分治算法,南加大、微软让大模型炼成“火眼金睛”
|
26天前
|
数据采集 缓存 算法
算法优化的常见策略有哪些
【10月更文挑战第20天】算法优化的常见策略有哪些
|
6月前
|
存储 监控 NoSQL
Redis处理大量数据主要依赖于其内存存储结构、高效的数据结构和算法,以及一系列的优化策略
【5月更文挑战第15天】Redis处理大量数据依赖内存存储、高效数据结构和优化策略。选择合适的数据结构、利用批量操作减少网络开销、控制批量大小、使用Redis Cluster进行分布式存储、优化内存使用及监控调优是关键。通过这些方法,Redis能有效处理大量数据并保持高性能。
97 1
|
3月前
|
缓存 算法 前端开发
深入理解缓存淘汰策略:LRU和LFU算法的解析与应用
【8月更文挑战第25天】在计算机科学领域,高效管理资源对于提升系统性能至关重要。内存缓存作为一种加速数据读取的有效方法,其管理策略直接影响整体性能。本文重点介绍两种常用的缓存淘汰算法:LRU(最近最少使用)和LFU(最不经常使用)。LRU算法依据数据最近是否被访问来进行淘汰决策;而LFU算法则根据数据的访问频率做出判断。这两种算法各有特点,适用于不同的应用场景。通过深入分析这两种算法的原理、实现方式及适用场景,本文旨在帮助开发者更好地理解缓存管理机制,从而在实际应用中作出更合理的选择,有效提升系统性能和用户体验。
191 1
|
3月前
|
算法 语音技术
支付宝商业化广告算法问题之在ODL模型优化过程中,采取什么策略来提高模型的泛化能力呢
支付宝商业化广告算法问题之在ODL模型优化过程中,采取什么策略来提高模型的泛化能力呢
|
3月前
|
算法 Java 应用服务中间件
探索JVM垃圾回收算法:选择适合你应用的最佳GC策略
探索JVM垃圾回收算法:选择适合你应用的最佳GC策略
|
4月前
|
机器学习/深度学习 数据采集 算法
Python实现ISSA融合反向学习与Levy飞行策略的改进麻雀优化算法优化支持向量机回归模型(SVR算法)项目实战
Python实现ISSA融合反向学习与Levy飞行策略的改进麻雀优化算法优化支持向量机回归模型(SVR算法)项目实战
170 9
|
4月前
|
机器学习/深度学习 数据采集 算法
Python实现ISSA融合反向学习与Levy飞行策略的改进麻雀优化算法优化支持向量机分类模型(SVC算法)项目实战
Python实现ISSA融合反向学习与Levy飞行策略的改进麻雀优化算法优化支持向量机分类模型(SVC算法)项目实战
|
4月前
|
缓存 算法 安全
Java中的数据结构与算法优化策略
Java中的数据结构与算法优化策略
|
5月前
|
算法 C++
算法与数据结构高手养成:朴素的贪心法(上)最优化策略
算法与数据结构高手养成:朴素的贪心法(上)最优化策略
38 1
算法与数据结构高手养成:朴素的贪心法(上)最优化策略