最快ViT | FaceBook提出LeViT,0.077ms的单图处理速度却拥有ResNet50的精度(文末附论文与源码)(一)

简介: 最快ViT | FaceBook提出LeViT,0.077ms的单图处理速度却拥有ResNet50的精度(文末附论文与源码)(一)

1 简介


本文的工作利用了基于注意力体系结构中的最新发现,该体系结构在高度并行处理硬件上具有竞争力。作者从卷积神经网络的大量文献中重新评估了原理,以将其应用于Transformer,尤其是分辨率降低的激活图。同时作者还介绍了Attention bias,一种将位置信息集成到视觉Transformer中的新方法。

图1 LeViT性能对比

最终作者提出了LeVIT:一种用于快速推理的混合神经网络。考虑在不同的硬件平台上采用不同的效率衡量标准,以最好地反映各种应用场景。作者通过广泛的实验表明该方法适用于大多数体系结构。总体而言,在速度/准确性的权衡方面,LeViT明显优于现有的卷积网络和视觉Transformer。例如,在ImageNet Top-1精度为80%的情况下,LeViT比CPU上的EfficientNet快3.3倍。

相同计算复杂度的情况下Transformer为什么快?

大多数硬件加速器(gpu,TPUs)被优化以用来执行大型矩阵乘法。在Transformer中,注意力机制和MLP块主要依靠这些操作。相比之下,卷积需要复杂的数据访问模式,因此它们的操作通常受io约束。这些考虑对于我们探索速度/精度的权衡是很重要的。

本文主要贡献:

  1. 采用注意力机制作为下采样机制的multi-stage transformer 结构;
  2. 一种计算效率高的patch descriptor,可以减少第一层特征的数量;
  3. 使用Translation-invariant attention bias取代ViT中的位置嵌入;
  4. 为了提高给定计算时间的网络容量,作者重新设计了Attention-MLP Block。

2 LeViT的设计


2.1 LeViT设计原则

LeViT以ViT的架构和DeiT的训练方法为基础,合并了对卷积架构有用的组件。第1步是获得Compatible Representation。如果不考虑classification embedding的作用,ViT就是一个处理激活映射的Layer的堆叠。

image.png

实际上,中间“Token”嵌入可以看作是FCN体系结构中传统的C×H×W激活映射(BCHW格式)。因此,适用于激活映射(池、卷积)的操作可以应用于DeiT的中间表征。

LeViT优化了计算体系结构,不一定是为了最小化参数的数量。ResNet系列比VGG更高效的设计原则之一是在其前2个阶段使用相对较小的计算预算应用strong resolution reductions。当激活映射到达ResNet的第3阶段时,其分辨率已经缩小到足以将卷积应用于小的激活映射,从而降低了计算成本。

2.2 LeViT组件

1、Patch embedding

初步分析表明,在transformer组的输入上应用一个小卷积可以提高精度。因此在LeViT中作者选择对输入应用4层3×3卷积(stride2)来降低分辨率。channel的数量是C=3,32,64,128,256。

以上操作减少了对transformer下层的激活映射的输入,同时不丢失重要信息。LeViT-256的patch extractor用184 MFLOPs将图像形状(3,224,224)转换为(256,14,14)。作为比较,ResNet-18的前10层使用1042 MFLOPs执行相同的dimensionality reduction。

为什么在transformer组的输入上应用一个小卷积可以提高精度?

image.png

2、No classification token

为了使用BCHW张量形式,LeViT删除了classification token。类似于卷积网络,在最后一个激活映射上使用GAP来代替,这将产生一个用于分类器的embedding。在训练中进行蒸馏,作者分别训练分类和蒸馏的Head。在测试时,平均2个Head的输出。在实践中,LeViT可以使用BNC或BCHW张量格式。

3、Normalization layers and activations

ViT架构中的FC层相当于1x1卷积。ViT在每个注意点和MLP单元之前使用层归一化。对于LeViT,每次卷积之后都要进行BN操作。然后与residual connection连接起来的每个BN权重参数初始化为零。BN可以与之前的卷积合并来进行推理,这比层归一化有运行优势(例如,在EfficientNet B0上,这种融合将GPU的推理速度提高了2倍)。而DeiT使用GELU函数,而LeViT的非线性激活都是Hardswish。

class Linear_BN(torch.nn.Sequential):
    def __init__(self, a, b, bn_weight_init=1, resolution=-100000):
        super().__init__()
        self.add_module('c', torch.nn.Linear(a, b, bias=False))
        bn = torch.nn.BatchNorm1d(b)
        torch.nn.init.constant_(bn.weight, bn_weight_init)
        torch.nn.init.constant_(bn.bias, 0)
        self.add_module('bn', bn)
        global FLOPS_COUNTER
        output_points = resolution**2
        FLOPS_COUNTER += a * b * output_points
    @torch.no_grad()
    def fuse(self):
        l, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps)**0.5
        w = l.weight * w[:, None]
        b = bn.bias - bn.running_mean * bn.weight / \
            (bn.running_var + bn.eps)**0.5
        m = torch.nn.Linear(w.size(1), w.size(0))
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
    def forward(self, x):
        l, bn = self._modules.values()
        x = l(x)
        return bn(x.flatten(0, 1)).reshape_as(x)

4、Multi-resolution pyramid

LeViT在transformer架构中集成了ResNet stage。在各个stage中,该体系结构类似于一个visual transformer:一个带有交替MLP和激活块的残差模块。下面是注意块的修改。

class Attention(torch.nn.Module):
    def __init__(self, dim, key_dim, num_heads=8,
                 attn_ratio=4,
                 activation=None,
                 resolution=14):
        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.nh_kd = nh_kd = key_dim * num_heads
        self.d = int(attn_ratio * key_dim)
        self.dh = int(attn_ratio * key_dim) * num_heads
        self.attn_ratio = attn_ratio
        h = self.dh + nh_kd * 2
        self.qkv = Linear_BN(dim, h, resolution=resolution)
        self.proj = torch.nn.Sequential(activation(), Linear_BN(
            self.dh, dim, bn_weight_init=0, resolution=resolution))
        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = torch.nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))
        self.register_buffer('attention_bias_idxs',
                             torch.LongTensor(idxs).view(N, N))
        global FLOPS_COUNTER
        #queries * keys
        FLOPS_COUNTER += num_heads * (resolution**4) * key_dim
        # softmax
        FLOPS_COUNTER += num_heads * (resolution**4)
        #attention * v
        FLOPS_COUNTER += num_heads * self.d * (resolution**4)
    @torch.no_grad()
    def train(self, mode=True):
        super().train(mode)
        if mode and hasattr(self, 'ab'):
            del self.ab
        else:
            self.ab = self.attention_biases[:, self.attention_bias_idxs]
    def forward(self, x):  # x (B,N,C)
        B, N, C = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.view(B, N, self.num_heads, -
                           1).split([self.key_dim, self.key_dim, self.d], dim=3)
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)
        attn = (
            (q @ k.transpose(-2, -1)) * self.scale
            +
            (self.attention_biases[:, self.attention_bias_idxs]
             if self.training else self.ab)
        )
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
        x = self.proj(x)
        return x

5、Downsampling

在LeViT stage之间,一个缩小的注意块减少了激活映射的大小:在Q转换之前应用一个subsampling,然后传播到soft activation的输出。这将一个大小为的输入张量映射到一个大小为的输出张量。由于尺度的变化这个注意块的使用没有残差连接。同时为了防止信息丢失,这里将注意力头的数量设为。

class Subsample(torch.nn.Module):
    def __init__(self, stride, resolution):
        super().__init__()
        self.stride = stride
        self.resolution = resolution
    def forward(self, x):
        B, N, C = x.shape
        x = x.view(B, self.resolution, self.resolution, C)[
            :, ::self.stride, ::self.stride].reshape(B, -1, C)
        return x

6、Attention bias instead of a positional embedding

在transformer架构中的位置嵌入是一个位置依赖可训练的向量,在将token嵌入输入到transformer块之前,将其添加到token嵌入。如果没有它,转换器输出将独立于输入标记的排列。位置嵌入的Ablations会导致分类精度的急剧下降。

然而,位置嵌入只包含在注意块序列的输入上。因此,由于位置编码对higher layer也很重要,所以它很可能仍然处于中间表示中。

因此,LeViT在每个注意块中提供位置信息,并在注意机制中明确地注入相对位置信息:只是在注意力图中添加了注意偏向。对于每个head ,每2个像素和之间的标量值计算方式为:

image.png

第一项是经典的注意力。第二个是translation-invariant attention bias。每个Head有H×W参数对应不同的像素偏移量。对称差异和鼓励用 flip invariance进行训练。

self.attention_biases = torch.nn.Parameter(
            torch.zeros(num_heads, len(attention_offsets)))

7、Smaller keys

由于translation-invariant attention bias偏置项减少了key对位置信息编码的压力,因此LeViT减少了key矩阵相对于V矩阵的大小。如果key大小为, V则有2D通道。key的大小可以减少计算key product 所需的时间。

对于没有残差连接的下采样层,将V的维数设置为4D,以防止信息丢失。

8、Attention activation

在使用常规线性投影组合不同Heads的输出之前,对product 应用Hardswish激活。这类似于ResNet bottleneck residual block,V是一个1×1卷积的输出,对应一个spatial卷积,projection是另一个1×1卷积。

9、Reducing the MLP blocks

在ViT中,MLP residual块是一个线性层,它将嵌入维数增加了4倍,然后用一个非线性将其减小到原来的嵌入维数。但是对于视觉架构,MLP通常在运行时间和参数方面比注意Block更昂贵。

对于LeViT, MLP是1x1卷积,然后是通常的BN。为了减少计算开销,将卷积的展开因子从4降低到2。一个设计目标是注意力和MLP块消耗大约相同数量的FLOPs

2.3 LeViT家族

相关文章
|
计算机视觉
迟到的 HRViT | Facebook提出多尺度高分辨率ViT,这才是原汁原味的HRNet思想(二)
迟到的 HRViT | Facebook提出多尺度高分辨率ViT,这才是原汁原味的HRNet思想(二)
273 0
|
机器学习/深度学习 算法 计算机视觉
经典神经网络论文超详细解读(五)——ResNet(残差网络)学习笔记(翻译+精读+代码复现)
经典神经网络论文超详细解读(五)——ResNet(残差网络)学习笔记(翻译+精读+代码复现)
3838 1
经典神经网络论文超详细解读(五)——ResNet(残差网络)学习笔记(翻译+精读+代码复现)
|
7月前
|
机器学习/深度学习 编解码 自然语言处理
南开提出全新ViT | Focal ViT融会贯通Gabor滤波器,实现ResNet18相同参数,精度超8.6%
南开提出全新ViT | Focal ViT融会贯通Gabor滤波器,实现ResNet18相同参数,精度超8.6%
221 0
|
编解码 测试技术 计算机视觉
LVT | ViT轻量化的曙光,完美超越MobileNet和ResNet系列(二)
LVT | ViT轻量化的曙光,完美超越MobileNet和ResNet系列(二)
251 0
LVT | ViT轻量化的曙光,完美超越MobileNet和ResNet系列(二)
|
机器学习/深度学习 编解码 计算机视觉
ResNet50 文艺复兴 | ViT 原作者让 ResNet50 精度达到82.8%,完美起飞!!!(二)
ResNet50 文艺复兴 | ViT 原作者让 ResNet50 精度达到82.8%,完美起飞!!!(二)
187 0
|
编解码 TensorFlow 算法框架/工具
ResNet50 文艺复兴 | ViT 原作者让 ResNet50 精度达到82.8%,完美起飞!!!(一)
ResNet50 文艺复兴 | ViT 原作者让 ResNet50 精度达到82.8%,完美起飞!!!(一)
172 0
|
机器学习/深度学习 编解码 数据可视化
超越 Swin、ConvNeXt | Facebook提出Neighborhood Attention Transformer
超越 Swin、ConvNeXt | Facebook提出Neighborhood Attention Transformer
171 0
|
机器学习/深度学习 vr&ar 计算机视觉
ShiftViT用Swin Transformer的精度跑赢ResNet的速度,论述ViT的成功不在注意力!(二)
ShiftViT用Swin Transformer的精度跑赢ResNet的速度,论述ViT的成功不在注意力!(二)
234 0
|
机器学习/深度学习 自然语言处理 算法
ShiftViT用Swin Transformer的精度跑赢ResNet的速度,论述ViT的成功不在注意力!(一)
ShiftViT用Swin Transformer的精度跑赢ResNet的速度,论述ViT的成功不在注意力!(一)
235 0
|
机器学习/深度学习 编解码 自然语言处理
LVT | ViT轻量化的曙光,完美超越MobileNet和ResNet系列(一)
LVT | ViT轻量化的曙光,完美超越MobileNet和ResNet系列(一)
410 0

热门文章

最新文章