先附上一张大家熟悉的CIOU图
我们知道,CIOU是在IOU的基础上考虑了两个真实框和预测框的中心点距离(图中的d)还有两个框的最小包裹框的对角线距离(图中的c,最小包裹矩形框是图中虚线部分)。比如当两个框不重叠的时候,那么此时的IOU是等于0的,这样会导致无法进行反向传播,而CIOU可以很好的解决这一问题。
下面我们看看从代码如何实现,我们网络的输出特征层一般是这种shape 【batch_size,feature_W,featuer_H,num_anchors,4】。比如yolo,那么会得到一个[batch_size,13,13,507,4]这样的张量。13是特征层大小,然后有3种锚框,因此该特征层有13*13*3=507个框,每个框又有4个值(center_x,center_y,w,h)表示框的中心点坐标和框的宽和高。
为了方便起见,那么我们现在假设一下,我现在只有两个框,一个真实框(自己标注的),一个预测框,不去管特征层尺寸,只是处理框,那么可以有以下两个张量。其中box1是预测框,box2是真实框(这里的数值没有做归一化处理,不影响理解)。这两个矩形框如下图(绿色真实框):
box1 = torch.tensor([[5.4555e+01, 1.7518e+01, 4.0713e+01, 3.3931e+01]]) # 预测框 box2 = torch.tensor([[7.8304e+01, 1.6306e+01, 4.9968e+01, 3.6646e+01]]) # 真实框 fig = plt.figure() ax = fig.add_subplot(111) rec1 = plt.Rectangle((box1[:, 0].numpy() - box1[:, 2].numpy() / 2, box1[:, 1].numpy() - box1[:, 3].numpy() / 2), box1[:, 2].numpy(), box1[:, 3].numpy(), fill=False) # 预测框 rec2 = plt.Rectangle((box2[:, 0].numpy() - box2[:, 2].numpy() / 2, box2[:, 1].numpy() - box2[:, 3].numpy() / 2), box2[:, 2].numpy(), box2[:, 3].numpy(), fill=False, color='g') # 真实框 ax.add_patch(rec1) ax.add_patch(rec2) plt.xlim(-10, 150) plt.ylim(-20, 110) plt.gca().invert_yaxis() plt.show()
然后我们可以连接一下这两个框的对角线
plt.plot((box1[:, 0].numpy(), box2[:, 0].numpy()), (box1[:, 1], box2[:, 1]), color='blue') # 连接中心点
因为张量给的box信息是xywh接下来是计算这两个框左上角和右下角坐标
那么我们如何计算呢?
我们可以观察box发现:
左上角x坐标=中心点x坐标-w/2;
左上角y坐标=中心点y坐标-h/2;
右下角x坐标=中心点x坐标+w/2;
右下角y坐标=中心点y坐标+h/2;
因此代码可以这样写:
b1_xy = box1[:, :2] # 取预测框的中心点 b1_wh = box1[:, 2:] # 取预测框的wh b1_wh_half = b1_wh / 2. b1_mins = b1_xy - b1_wh_half # 求得左上角坐标 b1_maxes = b1_xy + b1_wh_half # 求得右下角坐标 b2_xy = box2[:, :2] # 取真实框的中心点 b2_wh = box2[:, 2:] # 取真实框的wh b2_wh_half = b2_wh / 2. b2_mins = b2_xy - b2_wh_half # 求得左上角坐标 b2_maxes = b2_xy + b2_wh_half # 求得右下角坐标
然后接下来我们需要计算两个框相交的坐标,为计算IOU做准备。然后分析一下如何计算相交部分的左上角坐标和右下坐标,通过观察两个框的相交,可以发现:
相交左上角=max(box1左上坐标, box2左上坐标)
相交右下角=min(box1右下角坐标,box2右下角坐标)
因此代码可以这样写:
intersect_mins = torch.max(b1_mins, b2_mins) # 相交的左上角 intersect_max = torch.min(b1_maxes, b2_maxes) # 相交的右下角 intersect_wh = torch.max(intersect_max - intersect_mins, torch.zeros_like(intersect_max)) # 求得相交的w和h
其中上面代码中的 torch.zeros_like(intersect_max) 是防止两个框不相交,可令wh为0【相交面积也自然为0了】
那么我们就可以得到相交的面积了:
intersect_area = intersect_wh[:, 0] * intersect_wh[:, 1] # 相交面积
然后是计算两个框的并集面积,并集面积可以这样算:
并集面积=box1面积+box2面积-相交面积
因此代码如下:
b1_area = b1_wh[:, 0] * b1_wh[:, 1] # 预测框面积 b2_area = b2_wh[:, 0] * b2_wh[:, 1] # 真实框面积 union_area = b1_area + b2_area - intersect_area # 并集面积
然后我们就计算普通的IOU了(IOU就是交集面积和并集面积之比)
iou = intersect_area / union_area # 计算IOU
此时的iou为【大家可以在纸上计算一下】:
iou = tensor([0.2954])
然后是计算两个框的中心点的欧式距离,即公式中的。
# 计算中心差距d center_distance = torch.sum(torch.pow(b1_xy - b2_xy, 2), axis=-1) # 先求平方再相加,得到欧氏距离axis=-1是最后一个维度操作 d(b,bgt)
然后是计算包含两个框最小矩形的左上角和右下角(为计算对角线距离做准备),然后可以分析一下这个最小矩形怎么找,坐标怎么算。通过观察,可以发现:
最小矩形的左上角=min(box1_左上角,box2_左上角)
最小矩形的右下角=max(box1_右下角,box2_右下角)
# 计算包含两个框的最小框的左上角和右下角 closebox_min = torch.min(b1_mins, b2_mins) # 左上角 closebox_max = torch.max(b1_maxes, b2_maxes) # 右下角 closebox_wh = torch.max(closebox_max - closebox_min, torch.zeros_like(intersect_max))
然后我们就可以将其绘制一下,最小矩形对角线是图中浅蓝色部分 ,图中虚线是所求最小矩矩形。(绿色是真实框)
然后可以计算一个这个对角线距离:
# 计算对角线的距离 closebox_distance = torch.sum(torch.pow(closebox_max - closebox_min, 2), axis=-1)
我们现在已经得到了公式中所需要的所有数据了,接下来就是把上面所求组合成CIOU的公式。
因为我们上面已经计算过IOU了,现在计算一下公式中的【代码较长,大家对着公式看】,其中1e+6是为了防止分母出现0而已。
v = math.pow(math.atan(b2_wh[:, 0] / (b2_wh[:, 1] + 1e-6)) - math.atan(b1_wh[:, 0] / (b1_wh[:, 1] + 1e-6)), 2) * 4/math.pow(math.pi, 2)
那么也可以得到了:
alpha = v / (1 - iou + v)
然后我们就可以得到CIOU的整体公式了(代码中的clamp是将输入限定在一个区域,限制最小为1e-8也是为了防止分母为0):
ciou = iou - center_distance / torch.clamp(closebox_distance, min=1e-8) - alpha * v
此时的CIOU损失函数为loss_ciou = 1-ciou,所以:
loss_ciou = 1 - ciou
我们可以打印一下此刻的loss为多少:
loss_ciou = tensor([0.7970])
其实到这就已经可以了,不过我们可以在做一个小小的实验,如果我们把预测框(最开始图中的黑色框)往右边平移5个单位,即让预测框与真实框重合度增大(说明预测的又准了一点点),然后我们看下loss会怎么变。
iou = tensor([0.3905]) ciou = tensor([0.3258]) loss_ciou = tensor([0.6742])
我们可以看到,loss确实降低了,iou比之前上升了,说明我们的预测框确实离真实框距离又进了一步,预测也更准了那么一点点。
希望大家可以更好的理解CIOU的代码实现和分析过程。
完整代码:
import math import torch import matplotlib.pyplot as plt box1 = torch.tensor([[5.4555e+01, 1.7518e+01, 4.0713e+01, 3.3931e+01]]) # 预测框 box2 = torch.tensor([[7.8304e+01, 1.6306e+01, 4.9968e+01, 3.6646e+01]]) # 真实框 fig = plt.figure() ax = fig.add_subplot(111) rec1 = plt.Rectangle((box1[:, 0].numpy() - box1[:, 2].numpy() / 2, box1[:, 1].numpy() - box1[:, 3].numpy() / 2), box1[:, 2].numpy(), box1[:, 3].numpy(), fill=False) # 预测框 rec2 = plt.Rectangle((box2[:, 0].numpy() - box2[:, 2].numpy() / 2, box2[:, 1].numpy() - box2[:, 3].numpy() / 2), box2[:, 2].numpy(), box2[:, 3].numpy(), fill=False, color='g') # 真实框 plt.plot((box1[:, 0].numpy(), box2[:, 0].numpy()), (box1[:, 1], box2[:, 1]), color='blue') # 连接中心点 ax.add_patch(rec1) ax.add_patch(rec2) # plt.xlim(-10, 150) # plt.ylim(-20, 110) # plt.gca().invert_yaxis() # plt.show() b1_xy = box1[:, :2] # 取预测框的中心点 b1_wh = box1[:, 2:] # 取预测框的wh b1_wh_half = b1_wh / 2. b1_mins = b1_xy - b1_wh_half # 求得左上角坐标 b1_maxes = b1_xy + b1_wh_half # 求得右下角坐标 b2_xy = box2[:, :2] # 取真实框的中心点 b2_wh = box2[:, 2:] # 取真实框的wh b2_wh_half = b2_wh / 2. b2_mins = b2_xy - b2_wh_half # 求得左上角坐标 b2_maxes = b2_xy + b2_wh_half # 求得右下角坐标 intersect_mins = torch.max(b1_mins, b2_mins) # 相交的左上角 intersect_max = torch.min(b1_maxes, b2_maxes) # 相交的右下角 intersect_wh = torch.max(intersect_max - intersect_mins, torch.zeros_like(intersect_max)) # 求得相交的w和h intersect_area = intersect_wh[:, 0] * intersect_wh[:, 1] # 相交面积 b1_area = b1_wh[:, 0] * b1_wh[:, 1] # 预测框面积 b2_area = b2_wh[:, 0] * b2_wh[:, 1] # 真实框面积 union_area = b1_area + b2_area - intersect_area # 并集面积 iou = intersect_area / union_area # 计算IOU print("iou = ", iou) # CIOU = IOU - d(b,bgt)/c^2 - αv # 计算中心差距d center_distance = torch.sum(torch.pow(b1_xy - b2_xy, 2), axis=-1) # 先求平方再相加,得到欧氏距离axis=-1是最后一个维度操作 d(b,bgt) # 计算包含两个框的最小框的左上角和右下角 closebox_min = torch.min(b1_mins, b2_mins) # 左上角 closebox_max = torch.max(b1_maxes, b2_maxes) # 右下角 closebox_wh = torch.max(closebox_max - closebox_min, torch.zeros_like(intersect_max)) plt.plot((closebox_min[:, 0], closebox_max[:, 0]), (closebox_min[:, 1], closebox_max[:, 1])) # 绘制对角线 # 计算对角线的距离 closebox_distance = torch.sum(torch.pow(closebox_max - closebox_min, 2), axis=-1) # 计算ciou v = math.pow(math.atan(b2_wh[:, 0] / (b2_wh[:, 1] + 1e-6)) - math.atan(b1_wh[:, 0] / (b1_wh[:, 1] + 1e-6)), 2) * 4/math.pow(math.pi, 2) alpha = v / (1 - iou + v) ciou = iou - center_distance / torch.clamp(closebox_distance, min=1e-8) - alpha * v print("ciou = ",ciou) loss_ciou = 1 - ciou print("loss_ciou = ",loss_ciou) rec3 = plt.Rectangle((closebox_min[:, 0].numpy(), closebox_min[:, 1].numpy()), # xy closebox_max[:, 0].numpy() - closebox_min[:, 0].numpy(), # w closebox_max[:, 1].numpy() - closebox_min[:, 1].numpy(), fill=False, linestyle='dotted') # h ax.add_patch(rec3) plt.xlim(20, 120) plt.ylim(-5, 40) plt.gca().invert_yaxis() plt.show()