目标检测知识蒸馏---以SSD为例【附代码】上

简介: 笔记

在上一篇文章中有讲解以分类网络为例的知识蒸馏【分类网络知识蒸馏】,这篇文章将会针对目标检测网络进行蒸馏。


知识蒸馏是一种不改变网络结构模型压缩方法。这里的压缩需要和量化与剪枝进行区分,并不是严格意义上的压缩。这里将要讲的蒸馏是离线式蒸馏中的逻辑蒸馏【特征部分蒸馏以后会讲】,也是一种常用的方法,他是将已经训练好的teacher model对student model进行蒸馏。


teacher model是一个在精度表现上优良的模型,而student model往往是精度差一些,但推理速度高的模型。如果要采用这种蒸馏方式,需要注意的是两个Model的网络结构需要相似【因此可以将改进前后的model建立这种关系】。而实现部分最最重要的部分是建立蒸馏的Loss函数。


在目标检测中主要有两个任务,一个是分类,一个是边界的回归,前者的蒸馏是比较容易的,关键是在后者,这也是蒸馏的一个难点。


我们先来看一下SSD代码中的MultiBoxloss部分详解。


MultiBoxloss


SSD中分类loss采用CrossEntropy,边界loss采用平滑L1。具体公式和网络算法原理参考论文,这里不在多说。


loss参数说明

参数说明:


self.use_gpu:是否采用gpu训练


self.num_classes:训练类的数量【在SSD中num_classes是自己的类数量+背景类】


self.threshold:阈值,默认0.5


self.background_label:背景类标签,默认为0


self.encode_target:target编码


self.use_prior_for_matching:利用先眼眶做匹配


self.do_neg_mining:True,负样本挖掘


self.negpos_ratio:负样本比例,设置为3【正负样本比例为1:3】


self.variance :方差

class MultiBoxLoss(nn.Module):
    """SSD Weighted Loss Function
        Compute Targets:
            1) Produce Confidence Target Indices by matching  ground truth boxes
               with (default) 'priorboxes' that have jaccard index > threshold parameter
               (default threshold: 0.5).
            2) Produce localization target by 'encoding' variance into offsets of ground
               truth boxes and their matched  'priorboxes'.
            3) Hard negative mining to filter the excessive number of negative examples
               that comes with using a large number of default bounding boxes.
               (default negative:positive ratio 3:1)
        Objective Loss:
            L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
            Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
            weighted by α which is set to 1 by cross val.
            Args:
                c: class confidences,
                l: predicted boxes,
                g: ground truth boxes
                N: number of matched default boxes
            See: https://arxiv.org/pdf/1512.02325.pdf for more details.
        """
    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
                 use_gpu=True, negatives_for_hard=100.0):
        super(MultiBoxLoss, self).__init__()
        self.use_gpu = use_gpu
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        self.negatives_for_hard = negatives_for_hard
        self.variance = Config['variance']

loss forward部分

predictions:类型为tuple,网络的输出内容,包含:位置预测,分类置信度预测以及prior boxes预测。


predictions[0]的shape为:[batch,8732,4]


predictions[1]的shape为:[batch,8732,num_classes]


predictions[2]的shape为:[8732,4]


注:8732:以输入大小300*300为例,将在6个head部分产生8732个先眼眶.


8732= 38*38*4 + 19*19*6 + 10*10*6 + 5*5*6 + 3*3*6 + 1*1*4


target:包含了标注的数据集真实的boxes坐标以及label信息。是一个列表,列表的长度等于batch的数量,每个列表中的元素shape为[num_objs,5],num_objs表示你当前图像中标注的目标数量,5=boxes信息+label信息。

    def forward(self, predictions, targets):
        """Multibox Loss
                Args:
                    predictions (tuple): A tuple containing loc preds, conf preds,
                    and prior boxes from SSD net.
                        conf shape: torch.size(batch_size,num_priors,num_classes)
                        loc shape: torch.size(batch_size,num_priors,4)
                        priors shape: torch.size(num_priors,4)
                    pred_t (tuple): teacher's predictions
                    targets (tensor): Ground truth boxes and labels for a batch,
                        shape: [batch_size,num_objs,5] (last idx is the label).
                """
        #--------------------------------------------------#
        #   取出预测结果的三个值:回归信息,置信度,先验框
        #--------------------------------------------------#
        loc_data, conf_data, priors = predictions

创建两个全零张量用来做先验框和真实框的匹配,这里的num等于batch_size

loc_t = torch.zeros(num, num_priors, 4).type(torch.FloatTensor)
conf_t = torch.zeros(num, num_priors).long()

遍历每个batch,idx是batch的索引,truths是获取到的真实值的boxes信息。labels是获取到的当前图像中是什么类。


truths:tensor([[0.2333, 0.2067, 0.6967, 1.0000]], device='cuda:0')
labels:tensor([0.], device='cuda:0')
        for idx in range(num):
            # 获得真实框与标签
            truths = targets[idx][:, :-1]
            labels = targets[idx][:, -1]
            if(len(truths)==0):
                continue
            # 获得先验框
            defaults = priors
            #--------------------------------------------------#
            #   利用真实框和先验框进行匹配。
            #   如果真实框和先验框的重合度较高,则认为匹配上了。
            #   该先验框用于负责检测出该真实框。
            #--------------------------------------------------#
            match(self.threshold, truths, defaults, self.variance, labels, loc_t, conf_t, idx)

defaults是先验框,接下来是和真实框进行标签匹配。


标签匹配函数match

       传入参数:threshold,truths[真实boxes],defaults[先验框],variance[方差],labels[真实标签],loc_t[前面创建的全零张量,用来存放匹配后的boxes信息], conf_t[用来存储匹配后的置信度分类信息],idx[当前batch的索引 ]。


1.先计算所有先验框和真实框的的重合程度。


       box_a是就是上面的truths,box_b是先验框【注意先验框中的boxes形式是center_x,center_y,w,h,需要先转成左上角和右下角的形式】。最终就可以计算出IOU。

def jaccard(box_a, box_b):
    #-------------------------------------#
    #   返回的inter的shape为[A,B]
    #   代表每一个真实框和先验框的交矩形
    #-------------------------------------#
    inter = intersect(box_a, box_b)
    #-------------------------------------#
    #   计算先验框和真实框各自的面积
    #-------------------------------------#
    area_a = ((box_a[:, 2]-box_a[:, 0]) *
              (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
    area_b = ((box_b[:, 2]-box_b[:, 0]) *
              (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
    union = area_a + area_b - inter
    #-------------------------------------#
    #   每一个真实框和先验框的交并比[A,B]
    #-------------------------------------#
    return inter / union

因此得到的overlaps是计算的所有先验框和真实框的iou,shape为[1,8732]。


overlaps = jaccard(
        truths,
        point_form(priors)
    )

接下来是通过max函数获得这8732个先验框中与真实框匹配度最好的框和索引【就相当于可以把这个匹配的最好的认为是ground truth】。


可以得到iou最高的是0.6904,是第8711号先验框。


best_prior_overlap:tensor([[0.6904]], device='cuda:0')
best_prior_idx:tensor([[8711]], device='cuda:0')


用于保证每个真实框都有一个先验框与之匹配。

for j in range(best_prior_idx.size(0)):
        best_truth_idx[best_prior_idx[j]] = j
    best_truth_overlap.index_fill_(0, best_prior_idx, 2)

将truths扩充成8732.


matches = truths[best_truth_idx]

获取标签


conf = labels[best_truth_idx] + 1

获取背景类,通过设置的iou阈值进行过滤。


conf[best_truth_overlap < threshold] = 0

进行边界框的编码【其实就是将真实框和先验框进行匹配】。


def encode(matched, priors, variances):
    g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
    g_cxcy /= (variances[0] * priors[:, 2:])
    g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
    g_wh = torch.log(g_wh) / variances[1]
    return torch.cat([g_cxcy, g_wh], 1)

获得的loc shape为【8732,4】

loc = encode(matches, priors, variances)

将编码后的loc放入前面定义loc_t中,conf也是如此。

获得正样本。

# 所有conf_t>0的地方,代表内部包含物体
pos = conf_t > 0

此时的pos形式如下,shape为【batch,8732】:


tensor([[False, False, False,  ..., False, False,  True],

       [False, False, False,  ..., False, False, False]], device='cuda:0')


求和得到每个图像内有多少正样本。这就可以计算出在所有的batch中的所有batch*8732个框中有多少框内包含目标。


num_pos = pos.sum(dim=1, keepdim=True)


相关实践学习
在云上部署ChatGLM2-6B大模型(GPU版)
ChatGLM2-6B是由智谱AI及清华KEG实验室于2023年6月发布的中英双语对话开源大模型。通过本实验,可以学习如何配置AIGC开发环境,如何部署ChatGLM2-6B大模型。
目录
相关文章
|
机器学习/深度学习 存储 算法
【轻量化网络】概述网络进行轻量化处理中的:剪枝、蒸馏、量化
【轻量化网络】概述网络进行轻量化处理中的:剪枝、蒸馏、量化
561 0
|
前端开发 rax Python
Open3d系列 | 2. Open3d实现点云数据增强
Open3d系列 | 2. Open3d实现点云数据增强
3462 1
Open3d系列 | 2. Open3d实现点云数据增强
|
机器学习/深度学习 存储 编解码
Open3d系列 | 3. Open3d实现点云上采样、点云聚类、点云分割以及点云重建
Open3d系列 | 3. Open3d实现点云上采样、点云聚类、点云分割以及点云重建
13614 1
Open3d系列 | 3. Open3d实现点云上采样、点云聚类、点云分割以及点云重建
|
9月前
|
数据采集 搜索推荐 数据安全/隐私保护
Referer头部在网站反爬虫技术中的运用
Referer头部在网站反爬虫技术中的运用
|
7月前
|
机器学习/深度学习 编解码 计算机视觉
YOLOv11改进策略【注意力机制篇】| CVPRW-2024 分层互补注意力混合层 H-RAMi 针对低质量图像的特征提取模块
YOLOv11改进策略【注意力机制篇】| CVPRW-2024 分层互补注意力混合层 H-RAMi 针对低质量图像的特征提取模块
226 1
YOLOv11改进策略【注意力机制篇】| CVPRW-2024 分层互补注意力混合层 H-RAMi 针对低质量图像的特征提取模块
|
10月前
|
存储 Oracle NoSQL
【赵渝强老师】Oracle的体系架构
Oracle数据库的核心在于其体系架构,主要包括数据库与实例、存储结构、进程结构和内存结构。数据库由物理文件组成,实例则是内存和进程的组合。存储结构分为逻辑和物理两部分,进程结构涉及多个后台进程如SMON、PMON、DBWn等,内存结构则包含SGA和PGA。掌握这些知识有助于更好地管理和优化Oracle数据库。
323 7
|
数据可视化
Open3d Point cloud outlier removal 点云异常值移除
Open3d Point cloud outlier removal 点云异常值移除
407 1
|
固态存储 算法 计算机视觉
SSD算法1
8月更文挑战第9天
一文教你学会keil软件仿真
一文教你学会keil软件仿真
1942 1
|
编解码 监控 安全
GB/T28181规范扫盲和使用场景探讨
GB28181(GB/T 28181-2022)是中国国家标准,规定了安全防范视频监控联网系统的信息传输、交换、控制技术要求。此标准支持设备接入、音视频传输及控制指令交互等功能,适用于各类监控设备如执法记录仪和移动监控系统。技术实现涉及协议栈构建、音视频编码及数据传输等环节。广泛应用在执法记录、移动监控和铁路巡检等领域。例如,海康威视iSecure Center和萤石云平台均支持GB28181协议,实现设备管理和视频传输。此外,大牛直播SDK推出的SmartGBD为Android终端提供了便捷的GB28181接入解决方案,支持多种数据类型接入,增强了设备的互操作性。
1013 0

热门文章

最新文章