深度学习原理篇 第四章:DEIT

简介: 简要介绍DEIT的原理和代码实现。

参考教程:
https://arxiv.org/pdf/2012.12877.pdf
https://github.com/facebookresearch/deit

概述

在之前的章节中提到过,VIT模型训练的一个问题是对数据的要求比较高,因为基于transformer的模型相对于基于卷积的模型,更加flexible。卷积的模型有着预设好的感受野,而transformer的模型需要自己去学习哪部分更加重要,因此训练上也更困难。

在这种情况下,想独自训练一个效果比较好的transformer模型是很困难的,你很难准备大几百万的数据集用于训练。这也给论文复现带来了难度,你看别人的模型效果好,你想去学习,但是没有资源训练出相当的模型。

DEIT提出了一种基于token的蒸馏方法,使用和训练卷积网络差不多的时间,只用imagenet作为训练集,就实现了非常不错的效果。

总的来说,DEIT做出了以下贡献(这一段直接翻译的论文原文):

  • 证明了不包含卷积层的网络在只是用ImageNet数据的情况下也能取得很有竞争力的表现。
  • 提出了一种基于token的蒸馏方法,并且这个方法的效果明显超过了普通的蒸馏方法。
  • 有趣的是,基于transformer的模型以convnet为老师时表现的比以transformer为老师时要好。
  • 他们的基于imagenet预训练的模型应用于其它下游任务时效果也很不错。

Knowledge Distillation

在这里补充一点知识蒸馏相关的内容。

知识蒸馏简单来说呢,就是把我们想要训练的模型当作“学生”模型,在向我们的hard label,也就是ground truth的结果靠近的同时,也让它向一个“老师”模型(一般是一个效果更好的、体量更大的模型)输出的soft label靠近。

比较简单的方法就是直接让学生模型的输出logits去拟合老师模型的输出logits,复杂一点的会增加层与层之间的拟合。

下面的代码就来自一个比较早期的repohttps://github.com/haitongli/knowledge-distillation-pytorch/tree/master

可以看到KD_loss明显有两部分组成。

T = params.temperature
KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),
                             F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + F.cross_entropy(outputs, labels) * (1. - alpha)

第一个部分就是我们的软目标损失,使用KLD散度计算输出的logits与老师模型输出的logits的差距,T在这里是一个温度系数,T越大得到的概率分布就越平滑。第二个部分就是我们的硬目标损失,也就是输出与label的交叉熵损失。

DEIT

base model: VIT

首先来重新介绍一下DEIT方法中使用的模型框架,其实也就是复习了一遍VIT。

transformer block

DEIT的工作是在VIT模型的基础上完成的。使用固定大小的RGB图像作为输入,这个图像被拆解成N个大小为16*16的小patch,N的大小一般是14*14。也就是说默认图像的大小是224*224。

每个patch都会被处理成一个指定维度的token。在之前的章节中我们介绍过这里有两个常用做法,再次复述一下。

第一种做法是使用reshape之后,使用全连接层完成维度的变化。

self.proj = Rearrange('b c (h p) (w p ) -> b (h w) (p1 p2 c)', p = patch_size)
self.linear = nn.Linear(patch_size * patch_size * in_c, embed_dim)

第二种做法是直接使用卷积。

self.proj = nn.Conv2d(in_c, embed_dim, kernel_size = patch_size, stride=patch_size)

目前来说第二种方法是更常用的。

然后再给得到的embedding加上一个class_token和一个position_embeddings。就构成了一个完整的输入。

class token

VIT中模仿BERT的做法,在得到的patch embedding上concat了一个可训练的class token。这个class token也会贯穿整个网络,并且最终用于分类。它相当于起到了串联所有patch_embedding的作用,它包含的也是一个整体的信息。

也就是说在整个过程中,transformer一共使用了N+1个token,但是只有第一个class token被用来进行结果的预测。

position embedding

已知transformer中最重要的结构就是MSA,在MSA中会根据你的输入计算三个vector,分别是Query, Key, Value。并使用Q和K的内积计算attention。

我们直接看一下源码,可以看到这个qkv是通过全连接得到的,它完成的是从embed_dim到embed_dim的映射,这个过程是和embed的数量无关的。

self.qkv = nn.Linear(emb_size, emb_size*3)

所以一个在low-resolution的图像上训练的模型,也是很容易用在high-resolution的图像上的。只要使用一样的patch_size就可以。

这时候聪明的你可能会发现一个问题,patch_size大小一样,在high-resolution图像上得到的patch的数量肯定比low-resolution要多呀。那么position_embedding是会受到影响的,position_embedding的大小是和我们的数量以及embed_size都有关系的。

 self.positions = nn.Parameter(torch.randn((img_size // patch_size) **2 + 1, emb_size))

原VIT论文中的做法是这样的

We therefore perform 2D interpolation of the pre-trained position embeddings, according to their location in the original image.

Distillation through attention

作者在论文中对蒸馏的部分进行了比较详细的介绍。

soft distillation
软蒸馏就是上面介绍的,用学生模型的logits向老师模型的logits学习,两者的差距使用KL散度来衡量。

hard distillation
硬蒸馏是将老师模型预测的结果也作为真实的标签,让你的学生模型也去学习这个标签。
$$ L^{hardDistill}_{global} = (1-\epsilon)\times\frac{1}{2}L_{CE}(\psi(Z_s),y) + \epsilon\times\frac{1}{2}L_{CE}(\psi(Z_s),yt) $$
这种实现方法也更简单方便。老师模型预测的label和ground truth的label扮演一样的角色。

image.png

Distillation token
上图介绍了DEIT是如何进行token的蒸馏的。它们在原有的patch embedding的基础上(patch and class token)新增了一个额外的token,称为distillation token
distillation token和class token一样,在整个训练过程中和别的token进行交互,并在最后一层输出。
class_token的分类结果向ground_truth靠齐,distillation_token的分类结果向我们的teacher靠齐。

整体的原理还是很简单的,可以看作class_token和distillation_token各学各的,在最后测试的时候,两个token是合在一起使用的。

代码实现

DistilledVisionTransformer

参考的是这里的源码:https://github.com/facebookresearch/deit/blob/main/models.py

我们首先来看一下这个DEIT的类。

class DistilledVisionTransformer(VisionTransformer):

它是直接继承的VisionTransformer的类,并在此基础上进行了一些修改,这个修改也没有很大,比较好理解。

init()

首先,它增加了一个dist_token,这个token和class_token的大小是完全一样的,用一样的代码就可以定义。

self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None

然后它的position_embedding和之前不一样了。在不使用蒸馏的时候position_embedding的长度 = num_patch + 1 (class_token)。现在增加了一个新token,所以它的长度也增加了1,变成了num_patch + 2。

self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))

此外,除了原有的分类头外,现在增加了一个新的蒸馏头,用来预测distillation_token的结果。这个部分代码和之前的分类头也是一样的。

self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
 self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

forward()

在模型的forward中,之前只有一个输出,现在变成了两个。整体的流程是没有什么变化的。

在之前的章节中我们梳理过VIT的流程。

  1. 输入img,获得patch,并转成embedding的形式。
  2. 增加cls embedding和position embedding。
  3. 进入transformer encoder构成的blocks。每个block由两部分组成:
    1. multi-head attention
    2. mlp
  4. 进入mlp分类头,输出结果。

在DEIT中增加了distillation_token,所以流程变为了:

  1. 输入img,获得patch,并转成embedding的形式。
  2. 增加cls embedding和dist embedding和position embedding。
  3. 进入transformer encoder构成的blocks。
  4. cls token进入mlp分类头,dist token进入另一个分类头

第二点主要是输入的维度发生了变化,对整个训练流程是没有影响的。最后一点也不过是分开了两个输出。

embedding

在原版VIT中。

if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = x + self.pos_embed

在DIET中。

cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)

logits

在原版VIT中。

def forward_head(self, x, pre_logits: bool = False):
        # 这里的x是self.forward_features的结果。
        if self.global_pool:
            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        return x if pre_logits else self.head(x)

在DEIT中。
如果是训练中使用,两个结果分开输出,因为要分别计算loss。如果是在inference中,则使用两个输出融合的结果。

 def forward(self, x):
        x, x_dist = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.training:
            return x, x_dist
        else:
            # during inference, return the average of both classifier predictions
            return (x + x_dist) / 2

distillation loss

除了模型代码的改动外,DEIT中使用的loss也和之前不一样。
我们先来看一下loss的这个类。

init()

class DistillationLoss(torch.nn.Module):
    """
    This module wraps a standard criterion and adds an extra knowledge distillation loss by
    taking a teacher model prediction and using it as additional supervision.
    """
    def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
                 distillation_type: str, alpha: float, tau: float):
        super().__init__()
        self.base_criterion = base_criterion
        self.teacher_model = teacher_model
        assert distillation_type in ['none', 'soft', 'hard']
        self.distillation_type = distillation_type
        self.alpha = alpha
        self.tau = tau

这里传入的base_criterion是你打算用来计算你的分类损失的loss,也就是你的class_head预测的结果和你的图像类别的ground_truth的loss。

    if mixup_active:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    if args.bce_loss:
        criterion = torch.nn.BCEWithLogitsLoss()

第二个参数teacher model是你想要学习的老师模型,因为我们只用这个模型做预测,不用它参与训练,所以要注意使用

teacher_model.eval()

第三个参数distillation_type是让你选择你先用软标签还是硬标签的方法。
第四个参数alpha用于分类损失和蒸馏损失的权重分配。
第四个参数tau就是温度系数,在软标签才会用到。

forward()

def forward(self, inputs, outputs, labels)

损失函数forward的部分的输入有三个,第一个input是我们的原始输入,它会被送入teacher_model中用于计算teacher_model的输出。第二个outputs是我们的学生模型的输出结果,它实际上包括了output(head的输出)和output_kd(dist_head)的输出。第三个labels就是我们的ground truth。

我们的分类损失直接用self.base_criterion进行计算。

base_loss = self.base_criterion(outputs, labels)

蒸馏损失按照你选择的distillation_type可以分为两类:soft和hard。其实还有一个选项是None,这种情况下不使用蒸馏损失。

teacher_outputs = self.teacher_model(inputs)

假如你使用软损失。那么就是用你的dist_head的logits和teacher_model的logits进行比较。在计算中还是使用KL散度。并且这里还会用到我们的温度系数tau。

T = self.tau
distillation_loss = F.kl_div(F.log_softmax(outputs_kd/T,dim=1), F.log_softmax(teacher_outputs/T, dim=1),reduction='sum',log_target=True)*(T*T)/outputs_kd.numel()
# We provide the teacher's targets in log probability because we use log_target=True

假如你使用的是硬损失。那么就是用你的dist_head的logits和teacher_model输出的标签进行比较。

distillation_loss = F.cross_entropy(outputs_kd, teacher_outputs.argmax(dim=1))

最终输出的loss用alpha这个参数平衡了权重。

loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
相关文章
|
6天前
|
机器学习/深度学习 自然语言处理 自动驾驶
深度学习的工作原理
【5月更文挑战第3天】深度学习的工作原理
24 7
|
6天前
|
机器学习/深度学习 算法 vr&ar
揭秘深度学习中的对抗性网络:原理、挑战与未来
【4月更文挑战第17天】 在深度学习领域,对抗性网络(GANs)以其独特的博弈论框架重新定义了生成模型的边界。本文将深入探讨GANs的核心原理,包括其架构、训练过程以及判别器与生成器之间的动态交互。此外,我们还将审视当前GANs面临的主要挑战,如模式崩溃、训练不稳定性以及评估指标的选择问题,并探讨针对这些挑战的解决方案。最后,文章将展望GANs在多个领域的应用前景,从图像合成到增强现实,再到医疗影像分析,揭示其在推动技术创新方面的巨大潜力。
27 7
|
7天前
|
机器学习/深度学习 算法 计算机视觉
基于深度学习的停车位关键点检测系统(代码+原理)
基于深度学习的停车位关键点检测系统(代码+原理)
164 0
|
7天前
|
机器学习/深度学习 算法 算法框架/工具
基于深度学习的交通标志检测和识别(从原理到环境配置/代码运行)
基于深度学习的交通标志检测和识别(从原理到环境配置/代码运行)
277 0
|
6天前
|
机器学习/深度学习 人工智能 自然语言处理
深度学习疆界:探索基本原理与算法,揭秘应用力量,展望未来发展与智能交互的新纪元
深度学习疆界:探索基本原理与算法,揭秘应用力量,展望未来发展与智能交互的新纪元
39 0
|
7天前
|
机器学习/深度学习 数据采集 算法
【深度学习基础】反向传播BP算法原理详解及实战演示(附源码)
【深度学习基础】反向传播BP算法原理详解及实战演示(附源码)
79 0
|
7月前
|
机器学习/深度学习 数据采集 人工智能
朴素贝叶斯深度解码:从原理到深度学习应用
朴素贝叶斯深度解码:从原理到深度学习应用
153 1
|
9月前
|
机器学习/深度学习 PyTorch 区块链
深度学习原理篇 第十章:Pix2Seq
简要介绍pix2seq的原理和代码实现。
250 1
|
9月前
|
机器学习/深度学习 Shell 数据处理
深度学习原理篇 第九章:UP-DETR
简要介绍UP-DETR的原理和代码实现。
132 0
|
9月前
|
机器学习/深度学习 数据采集 编解码
深度学习原理篇 第八章:CLIP
简要介绍clip的原理和代码实现。
1241 0