在上一篇文章中有讲解以分类网络为例的知识蒸馏【分类网络知识蒸馏】,这篇文章将会针对目标检测网络进行蒸馏。
知识蒸馏是一种不改变网络结构模型压缩方法。这里的压缩需要和量化与剪枝进行区分,并不是严格意义上的压缩。这里将要讲的蒸馏是离线式蒸馏中的逻辑蒸馏【特征部分蒸馏以后会讲】,也是一种常用的方法,他是将已经训练好的teacher model对student model进行蒸馏。
teacher model是一个在精度表现上优良的模型,而student model往往是精度差一些,但推理速度高的模型。如果要采用这种蒸馏方式,需要注意的是两个Model的网络结构需要相似【因此可以将改进前后的model建立这种关系】。而实现部分最最重要的部分是建立蒸馏的Loss函数。
在目标检测中主要有两个任务,一个是分类,一个是边界的回归,前者的蒸馏是比较容易的,关键是在后者,这也是蒸馏的一个难点。
我们先来看一下SSD代码中的MultiBoxloss部分详解。
MultiBoxloss
SSD中分类loss采用CrossEntropy,边界loss采用平滑L1。具体公式和网络算法原理参考论文,这里不在多说。
loss参数说明
参数说明:
self.use_gpu:是否采用gpu训练
self.num_classes:训练类的数量【在SSD中num_classes是自己的类数量+背景类】
self.threshold:阈值,默认0.5
self.background_label:背景类标签,默认为0
self.encode_target:target编码
self.use_prior_for_matching:利用先眼眶做匹配
self.do_neg_mining:True,负样本挖掘
self.negpos_ratio:负样本比例,设置为3【正负样本比例为1:3】
self.variance :方差
class MultiBoxLoss(nn.Module): """SSD Weighted Loss Function Compute Targets: 1) Produce Confidence Target Indices by matching ground truth boxes with (default) 'priorboxes' that have jaccard index > threshold parameter (default threshold: 0.5). 2) Produce localization target by 'encoding' variance into offsets of ground truth boxes and their matched 'priorboxes'. 3) Hard negative mining to filter the excessive number of negative examples that comes with using a large number of default bounding boxes. (default negative:positive ratio 3:1) Objective Loss: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss weighted by α which is set to 1 by cross val. Args: c: class confidences, l: predicted boxes, g: ground truth boxes N: number of matched default boxes See: https://arxiv.org/pdf/1512.02325.pdf for more details. """ def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target, use_gpu=True, negatives_for_hard=100.0): super(MultiBoxLoss, self).__init__() self.use_gpu = use_gpu self.num_classes = num_classes self.threshold = overlap_thresh self.background_label = bkg_label self.encode_target = encode_target self.use_prior_for_matching = prior_for_matching self.do_neg_mining = neg_mining self.negpos_ratio = neg_pos self.neg_overlap = neg_overlap self.negatives_for_hard = negatives_for_hard self.variance = Config['variance']
loss forward部分
predictions:类型为tuple,网络的输出内容,包含:位置预测,分类置信度预测以及prior boxes预测。
predictions[0]的shape为:[batch,8732,4]
predictions[1]的shape为:[batch,8732,num_classes]
predictions[2]的shape为:[8732,4]
注:8732:以输入大小300*300为例,将在6个head部分产生8732个先眼眶.
8732= 38*38*4 + 19*19*6 + 10*10*6 + 5*5*6 + 3*3*6 + 1*1*4
target:包含了标注的数据集真实的boxes坐标以及label信息。是一个列表,列表的长度等于batch的数量,每个列表中的元素shape为[num_objs,5],num_objs表示你当前图像中标注的目标数量,5=boxes信息+label信息。
def forward(self, predictions, targets): """Multibox Loss Args: predictions (tuple): A tuple containing loc preds, conf preds, and prior boxes from SSD net. conf shape: torch.size(batch_size,num_priors,num_classes) loc shape: torch.size(batch_size,num_priors,4) priors shape: torch.size(num_priors,4) pred_t (tuple): teacher's predictions targets (tensor): Ground truth boxes and labels for a batch, shape: [batch_size,num_objs,5] (last idx is the label). """ #--------------------------------------------------# # 取出预测结果的三个值:回归信息,置信度,先验框 #--------------------------------------------------# loc_data, conf_data, priors = predictions
创建两个全零张量用来做先验框和真实框的匹配,这里的num等于batch_size
loc_t = torch.zeros(num, num_priors, 4).type(torch.FloatTensor) conf_t = torch.zeros(num, num_priors).long()
遍历每个batch,idx是batch的索引,truths是获取到的真实值的boxes信息。labels是获取到的当前图像中是什么类。
truths:tensor([[0.2333, 0.2067, 0.6967, 1.0000]], device='cuda:0') labels:tensor([0.], device='cuda:0')
for idx in range(num): # 获得真实框与标签 truths = targets[idx][:, :-1] labels = targets[idx][:, -1] if(len(truths)==0): continue # 获得先验框 defaults = priors #--------------------------------------------------# # 利用真实框和先验框进行匹配。 # 如果真实框和先验框的重合度较高,则认为匹配上了。 # 该先验框用于负责检测出该真实框。 #--------------------------------------------------# match(self.threshold, truths, defaults, self.variance, labels, loc_t, conf_t, idx)
defaults是先验框,接下来是和真实框进行标签匹配。
标签匹配函数match
传入参数:threshold,truths[真实boxes],defaults[先验框],variance[方差],labels[真实标签],loc_t[前面创建的全零张量,用来存放匹配后的boxes信息], conf_t[用来存储匹配后的置信度分类信息],idx[当前batch的索引 ]。
1.先计算所有先验框和真实框的的重合程度。
box_a是就是上面的truths,box_b是先验框【注意先验框中的boxes形式是center_x,center_y,w,h,需要先转成左上角和右下角的形式】。最终就可以计算出IOU。
def jaccard(box_a, box_b): #-------------------------------------# # 返回的inter的shape为[A,B] # 代表每一个真实框和先验框的交矩形 #-------------------------------------# inter = intersect(box_a, box_b) #-------------------------------------# # 计算先验框和真实框各自的面积 #-------------------------------------# area_a = ((box_a[:, 2]-box_a[:, 0]) * (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] area_b = ((box_b[:, 2]-box_b[:, 0]) * (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] union = area_a + area_b - inter #-------------------------------------# # 每一个真实框和先验框的交并比[A,B] #-------------------------------------# return inter / union
因此得到的overlaps是计算的所有先验框和真实框的iou,shape为[1,8732]。
overlaps = jaccard( truths, point_form(priors) )
接下来是通过max函数获得这8732个先验框中与真实框匹配度最好的框和索引【就相当于可以把这个匹配的最好的认为是ground truth】。
可以得到iou最高的是0.6904,是第8711号先验框。
best_prior_overlap:tensor([[0.6904]], device='cuda:0') best_prior_idx:tensor([[8711]], device='cuda:0')
用于保证每个真实框都有一个先验框与之匹配。
for j in range(best_prior_idx.size(0)): best_truth_idx[best_prior_idx[j]] = j best_truth_overlap.index_fill_(0, best_prior_idx, 2)
将truths扩充成8732.
matches = truths[best_truth_idx]
获取标签
conf = labels[best_truth_idx] + 1
获取背景类,通过设置的iou阈值进行过滤。
conf[best_truth_overlap < threshold] = 0
进行边界框的编码【其实就是将真实框和先验框进行匹配】。
def encode(matched, priors, variances): g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2] g_cxcy /= (variances[0] * priors[:, 2:]) g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] g_wh = torch.log(g_wh) / variances[1] return torch.cat([g_cxcy, g_wh], 1)
获得的loc shape为【8732,4】
loc = encode(matches, priors, variances)
将编码后的loc放入前面定义loc_t中,conf也是如此。
获得正样本。
# 所有conf_t>0的地方,代表内部包含物体 pos = conf_t > 0
此时的pos形式如下,shape为【batch,8732】:
tensor([[False, False, False, ..., False, False, True],
[False, False, False, ..., False, False, False]], device='cuda:0')
求和得到每个图像内有多少正样本。这就可以计算出在所有的batch中的所有batch*8732个框中有多少框内包含目标。
num_pos = pos.sum(dim=1, keepdim=True)