目标检测知识蒸馏---以SSD为例【附代码】下

简介: 笔记

loss计算

取出所有的正样本计算loss


获得所有正样本的idx,返回形式是Truth or False.


pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)

通过索引在loc_data[预测的位置]中选择出正样本的loc_p【也就是预测目标的loc】。


loc_p = loc_data[pos_idx].view(-1, 4)

通过正样本的索引在loc_t【groud truth】中进行筛选获得正样本的loc_t。


loc_t = loc_t[pos_idx].view(-1, 4)

计算边界回归loss:


直接调用smooth_l1_loss计算loss【loc_p是预测值,loc_t是真实值】


loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False)

分类loss:


获得网络预测的conf,进行reshape,这就获得了所有batch中预测框内的conf,shape为【batch*8732,num_classes】。


batch_conf = conf_data.view(-1, self.num_classes)

conf_p是预测值【筛选后具有正样本的】,

# 这个地方是在寻找难分类的先验框
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))
        loss_c = loss_c.view(num, -1)
        # 难分类的先验框不把正样本考虑进去,只考虑难分类的负样本
        loss_c[pos] = 0 
        #--------------------------------------------------#
        #   loss_idx    (num, num_priors)
        #   idx_rank    (num, num_priors)
        #--------------------------------------------------#
        _, loss_idx = loss_c.sort(1, descending=True)
        _, idx_rank = loss_idx.sort(1)
        #--------------------------------------------------#
        #   求和得到每一个图片内部有多少正样本
        #   num_pos     (num, )
        #   neg         (num, num_priors)
        #--------------------------------------------------#
        num_pos = pos.long().sum(1, keepdim=True)
        # 限制负样本数量
        num_neg = torch.clamp(self.negpos_ratio * num_pos, max = pos.size(1) - 1)
        num_neg[num_neg.eq(0)] =  self.negatives_for_hard
        neg = idx_rank < num_neg.expand_as(idx_rank)
        #--------------------------------------------------#
        #   求和得到每一个图片内部有多少正样本
        #   pos_idx   (num, num_priors, num_classes)
        #   neg_idx   (num, num_priors, num_classes)
        #--------------------------------------------------#
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        # 选取出用于训练的正样本与负样本,计算loss
        conf_p = conf_data[(pos_idx + neg_idx).gt(0)].view(-1, self.num_classes)
        targets_weighted = conf_t[(pos + neg).gt(0)]
        loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False)

最后总的Loss为:

loss:8.0996


MultiBoxloss_KD


在原来的loss基础上加入了soft-target loss部分。

class MultiBoxLoss_KD(nn.Module):
    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
                 use_gpu=True, negatives_for_hard=100.0,neg_w=1.5, pos_w=1.0, Temp=1., reg_m=0.):
        super(MultiBoxLoss_KD, self).__init__()
        self.use_gpu = use_gpu
        self.num_classes = num_classes  # 21
        self.threshold = overlap_thresh  # 0.5
        self.background_label = bkg_label  # 0
        self.encode_target = encode_target  # False
        self.use_prior_for_matching = prior_for_matching  # True
        self.do_neg_mining = neg_mining  # True
        self.negpos_ratio = neg_pos  # 3
        self.neg_overlap = neg_overlap  # 0.5
        self.variance = Config['variance']
        self.negatives_for_hard = negatives_for_hard
        # soft-target loss
        self.neg_w = neg_w # 负样本(背景)权重
        self.pos_w = pos_w # 正样本权重
        self.Temp = Temp # 温度
        self.reg_m = reg_m

在forward部分传入参数为predictions[student的输出,pred_t为teacher的输出,targets是真实值]

def forward(self, predictions, pred_t, targets):
        """Multibox Loss
        Args:
            predictions (tuple): A tuple containing loc preds, conf preds,
            and prior boxes from SSD net.
                conf shape: torch.size(batch_size,num_priors,num_classes)
                loc shape: torch.size(batch_size,num_priors,4)
                priors shape: torch.size(num_priors,4)
            pred_t (tuple): teacher's predictions
            targets (tensor): Ground truth boxes and labels for a batch,
                shape: [batch_size,num_objs,5] (last idx is the label).
        """

kd for loc regression

这里的loc regression采用的是l2 loss.

        # teach1  这里的s指student,t指真实值
        loc_teach1_p = loc_teach1[pos_idx].view(-1, 4)  # loc_teach1_p = tensor<(3, 4), float32, cuda:0, grad>
        l2_dis_s = (loc_p - loc_t).pow(2).sum(1)  # Σ(loc_p-loc_t)² 计算学生L2 loss,(学生预测loc-真实标签)²  sum(1)求行和  l2_dis_s = tensor<(3,), float32, cuda:0, grad>
        l2_dis_s_m = l2_dis_s + self.reg_m  # l2_dis_s_m = tensor<(3,), float32, cuda:0, grad>
        l2_dis_t = (loc_teach1_p - loc_t).pow(2).sum(1)  # L2 loss:(老师loc预测值-真实标签)²并求和  l2_dis_t = tensor<(3,), float32, cuda:0, grad>
        l2_num = l2_dis_s_m > l2_dis_t  # 判断学生位置回归与真实reg距离 和 老师位置回归与真实标签距离 的大小  l2_num = tensor<(3,), bool, cuda:0>
        l2_loss_teach1 = l2_dis_s[l2_num].sum()  # 当学生大于老师 Lb(Rs,Rt,y)=Σ(loc_p-loc_t)²,否则为0 Lb表示文章定义的teacher bounded regression loss
                                                 # l2_loss_teach1 = tensor<(), float32, cuda:0, grad> 取出l2_num=True的
        l2_loss = l2_loss_teach1  # l2_loss = tensor<(), float32, cuda:0, grad>

kd for conf regression

conf_p是预测值,ps是student的分类预测,pt是teacher的分类预测,计算两者loss。

        # soft loss for Knowledge Distillation
        # teach1
        conf_p_teach = conf_teach1[(pos_idx + neg_idx).gt(0)].view(-1, self.num_classes)
        pt = F.softmax(conf_p_teach / self.Temp, dim=1)
        if self.neg_w > 1.:
            ps = F.softmax(conf_p / self.Temp, dim=1)
            soft_loss1 = KL_div(pt, ps, self.pos_w, self.neg_w) * (self.Temp ** 2)
        else:
            ps = F.log_softmax(conf_p / self.Temp, dim=1)
            soft_loss1 = nn.KLDivLoss(size_average=False)(ps, pt) * (self.Temp ** 2)
        soft_loss = soft_loss1

最后返回有4个loss,soft_loss是分类kd loss,l2loss是loc 蒸馏loss,loss_c, loss_l均为hard loss中的student自己的loss。

        loss_l /= N
        loss_c /= N
        l2_loss /= N
        soft_loss /= N
        return soft_loss, l2_loss, loss_c, loss_l

然后可以根据自己的情况给不同的loss分配不同的权重进行训练。

            soft_loss, l2_loss, loss_c, loss_l = criterion(out_student, teacher_out, targets)  # KD损失函数
            # loss_l, loss_c = criterion1(out_student, targets) # criterion1原损失函数
            loss = (0.3 * soft_loss + 0.7 * loss_c) + (0.5 * l2_loss + loss_l)

训练如下(这里为了方便演示我这里只放了100张图片训练):

Epoch 9/10: 100%|██████████| 54/54 [00:19<00:00,  2.80it/s, conf_loss=2.37, loc_loss=0.752, lr=7.16e-5]
Start Teacher Validation
Epoch 9/10: 100%|██████████| 6/6 [00:02<00:00,  2.01it/s, conf_loss=2.24, loc_loss=0.669, lr=7.16e-5]
Finish Teacher Validation
Epoch:9/10
Total Loss: 3.0658 || Val Loss: 2.4932 
Saving state, iter: 9
Start Teacher Train
Epoch 10/10: 100%|██████████| 54/54 [00:19<00:00,  2.79it/s, conf_loss=2.28, loc_loss=0.73, lr=6.59e-5]
Epoch 10/10:   0%|          | 0/6 [00:00<?, ?it/s<class 'dict'>]Start Teacher Validation
Epoch 10/10: 100%|██████████| 6/6 [00:03<00:00,  1.97it/s, conf_loss=2.14, loc_loss=0.691, lr=6.59e-5]
Finish Teacher Validation
Epoch:10/10
Total Loss: 2.9564 || Val Loss: 2.4282 
Saving state, iter: 10
开始蒸馏训练
Loading weights into state dict...
Finished!
Epoch 1/2:   0%|          | 0/54 [00:00<?, ?it/s<class 'dict'>]Start teacher2student_KD Train
Epoch 1/2: 100%|██████████| 54/54 [00:20<00:00,  2.60it/s, conf_loss=3.19, l2_loss=8.29, loc_loss=2.91, lr=0.0005, soft_loss=3.9]
Start Teacher2student_KD Validation
Epoch 1/2: 100%|██████████| 6/6 [00:02<00:00,  2.12it/s, conf_loss=2.75, loc_loss=2.56, lr=0.0005]
Finish teacher2student_KD Validation
Epoch:1/2
Total Loss: 17.9524 || Val Loss: 4.5457 
Saving state, iter: 1
Start teacher2student_KD Train
Epoch 2/2: 100%|██████████| 54/54 [00:20<00:00,  2.58it/s, conf_loss=2.99, l2_loss=6.52, loc_loss=2.53, lr=0.00046, soft_loss=3.41]
Start Teacher2student_KD Validation
Epoch 2/2: 100%|██████████| 6/6 [00:02<00:00,  2.14it/s, conf_loss=2.65, loc_loss=2.68, lr=0.00046]
Finish teacher2student_KD Validation
Epoch:2/2
Total Loss: 15.1709 || Val Loss: 4.5685 
Saving state, iter: 2
Epoch 3/4:   0%|          | 0/54 [00:00<?, ?it/s<class 'dict'>]Start teacher2student_KD Train
Epoch 3/4: 100%|██████████| 54/54 [00:25<00:00,  2.12it/s, conf_loss=2.66, l2_loss=6.92, loc_loss=2.63, lr=0.0001, soft_loss=3.05]
Epoch 3/4:   0%|          | 0/6 [00:00<?, ?it/s<class 'dict'>]Start Teacher2student_KD Validation
Epoch 3/4: 100%|██████████| 6/6 [00:02<00:00,  2.21it/s, conf_loss=2.39, loc_loss=2.66, lr=0.0001]
Finish teacher2student_KD Validation
Epoch:3/4
Total Loss: 14.9715 || Val Loss: 4.3286 
Saving state, iter: 3
Start teacher2student_KD Train
Epoch 4/4: 100%|██████████| 54/54 [00:25<00:00,  2.12it/s, conf_loss=2.44, l2_loss=6.46, loc_loss=2.54, lr=9.2e-5, soft_loss=2.84]
Start Teacher2student_KD Validation
Epoch 4/4: 100%|██████████| 6/6 [00:02<00:00,  2.15it/s, conf_loss=2.38, loc_loss=2.57, lr=9.2e-5]
Finish teacher2student_KD Validation
Epoch:4/4
Total Loss: 14.0245 || Val Loss: 4.2435 
Saving state, iter: 4

注:离线蒸馏训练对于teacher model也是有要求的,我这里的teacher model只是随便在原model的基础上改了一下训练而已,我这里仅仅是演示一下,具体的改进等需要自己去不断尝试。因此kd的好坏是取决于两个模型的。


大家也可以尝试其他的蒸馏方式,有问题可评论留言~~欢迎支持


目录
相关文章
|
1月前
|
缓存 并行计算 C++
实践教程|旋转目标检测模型-TensorRT 部署(C++)
实践教程|旋转目标检测模型-TensorRT 部署(C++)
68 0
|
22天前
|
机器学习/深度学习 存储 测试技术
【YOLOv8改进】 YOLOv8 更换骨干网络之 GhostNet :通过低成本操作获得更多特征 (论文笔记+引入代码).md
YOLO目标检测专栏探讨了卷积神经网络的创新改进,如Ghost模块,它通过低成本运算生成更多特征图,降低资源消耗,适用于嵌入式设备。GhostNet利用Ghost模块实现轻量级架构,性能超越MobileNetV3。此外,文章还介绍了SegNeXt,一个高效卷积注意力网络,提升语义分割性能,参数少但效果优于EfficientNet-L2。专栏提供YOLO相关基础解析、改进方法和实战案例。
|
9天前
|
测试技术 计算机视觉
【YOLOv8性能对比试验】YOLOv8n/s/m/l/x不同模型尺寸大小的实验结果对比及结论参考
【YOLOv8性能对比试验】YOLOv8n/s/m/l/x不同模型尺寸大小的实验结果对比及结论参考
|
9天前
|
机器学习/深度学习 计算机视觉
【YOLO性能对比试验】YOLOv9c/v8n/v6n/v5n的训练结果对比及结论参考
【YOLO性能对比试验】YOLOv9c/v8n/v6n/v5n的训练结果对比及结论参考
|
1月前
|
机器学习/深度学习 编解码 Unix
超分数据集概述和超分经典网络模型总结
超分数据集概述和超分经典网络模型总结
49 1
|
1月前
|
机器学习/深度学习 人工智能 自然语言处理
RT-DETR原理与简介(干翻YOLO的最新目标检测项目)
RT-DETR原理与简介(干翻YOLO的最新目标检测项目)
304 1
|
10月前
|
机器学习/深度学习 编解码 计算机视觉
【轻量化网络系列(2)】MobileNetV2论文超详细解读(翻译 +学习笔记+代码实现)
【轻量化网络系列(2)】MobileNetV2论文超详细解读(翻译 +学习笔记+代码实现)
797 0
【轻量化网络系列(2)】MobileNetV2论文超详细解读(翻译 +学习笔记+代码实现)
|
10月前
|
机器学习/深度学习 计算机视觉 文件存储
【轻量化网络系列(3)】MobileNetV3论文超详细解读(翻译 +学习笔记+代码实现)
【轻量化网络系列(3)】MobileNetV3论文超详细解读(翻译 +学习笔记+代码实现)
2318 0
【轻量化网络系列(3)】MobileNetV3论文超详细解读(翻译 +学习笔记+代码实现)
|
10月前
|
计算机视觉 编解码 机器学习/深度学习
【轻量化网络系列(1)】MobileNetV1论文超详细解读(翻译 +学习笔记+代码实现)
【轻量化网络系列(1)】MobileNetV1论文超详细解读(翻译 +学习笔记+代码实现)
492 0
【轻量化网络系列(1)】MobileNetV1论文超详细解读(翻译 +学习笔记+代码实现)
|
机器学习/深度学习 算法 数据可视化
详细解读GraphFPN | 如何用图模型提升目标检测模型性能?
详细解读GraphFPN | 如何用图模型提升目标检测模型性能?
151 0