Backbone | 谷歌提出LambdaNetworks:无需注意力让网络更快更强(文末获取论文源码)(二)

简介: Backbone | 谷歌提出LambdaNetworks:无需注意力让网络更快更强(文末获取论文源码)(二)

4. Lambda Layer


4.1 Context转换为线性函数

Lambda Layer将输入和Context C作为输入,并生成线性函数lambdas,然后应用于Query产生输出。不失一般性假设。和self-attention一样可能存在。

下图给出了lambda层的计算过程:

image.png

这里大致描述应用到single query的lambda layer:

1、生成与Context相关的Lambda函数:

这里希望生成一个到的线性函数,即一个矩阵。Lambda层首先通过线性预测Context和keys(通过Softmax规范化操作得到)在Context的位置,计算值和键。用标准化键和位置嵌入将集合得到矩阵如下:

image.png

式中,定义content lambda 和position lambda 为:

  • content lambda :content lambda在所有query位置上共享,并且对context元素的排列是不变的。它对如何仅基于context内容转换query 进行编码;
  • position lambda :position lambda通过位置嵌入依赖于query位置。它编码如何根据context元素及其相对位置转换query 到query 。

2、将Lambda函数用于query:

query 是通过学习线性投影从输入得到的,Lambda层的输出为:

image.png

3、Lambda层的解释:

矩阵的列可以看作是一个固定大小的context特征集。这些context特征基于context的content(content-based interactions)和结构(position-based interactions)进行聚合。然后应用lambda根据query动态分布这些context特征,以产生的输出。这个过程捕获内容和基于位置的互动,而不产生attention maps。

4、Normalization:

作者实验表明,在计算query和value后应用batch normalization是有帮助的。

4.2 减少复杂性的Multi-Query

由于输出维度比较大可能会带来比较大的计算复杂度,因此作者还设计了Multi-query Lambda Layer以减少复杂度,进而降低推理时间。

本文从输出维解耦lambda层的时间和空间复杂性。不是强加,而是创建 query ,对每个query 应用相同的,并将输出cat为。现在有,这将复杂度降低了的1倍。head的数目控制了的大小相对于query 的总大小。

def lambda layer(queries, keys, embeddings, values):
    """Multi−query lambda layer."""
    # b: batch, n: input length, m: context length,
    # k: query/key depth, v: value depth,
    # h: number of heads, d: output dimension.
    content lambda = einsum(softmax(keys), values, 'bmk,bmv−>bkv')
    position lambdas = einsum(embeddings, values, 'nmk,bmv−>bnkv')
    content output = einsum(queries, content lambda, 'bhnk,bkv−>bnhv')
    position output = einsum(queries, position lambdas, 'bhnk,bnkv−>bnhv')
    output = reshape(content output + position output, [b, n, d])
    return output

虽然这类似于multi-head或multi-query attention formulation,但动机是不同的。在attention操作中使用多个query增加了表征能力和复杂性。相反,在lambda层中使用多个query降低了复杂性和表示能力(忽略额外的query)。

最后,作者指出,可以将Multi-Lambda Layer扩展到linear attention,可以将其视为只包含content的lambda层。

4.3 使Lambda层translation equivariance

使用相对位置嵌入可以对context的结构做出明确的假设。特别地,在许多学习场景中,translation equivariance(即移动输入导致输出等效移动的特性)是一种强烈的归纳bias。

本文通过确保嵌入的位置满足对任何translation  获得translation equivariance位置的相互作用。在实践中,定义了一个张量的相对位置嵌入, index r为可能的相对位置对。

4.4 Lambda卷积

尽管远距离相互作用有诸多好处,但在许多任务中,局部性仍然是一种强烈的感应偏向。使用全局context可能会被证明是noisy或computationally excessive。因此,就像local self-attention和卷积一样将位置交互的范围限制在query位置周围的一个局部邻域可能是有用的。这可以通过将context位置在所需范围之外的相对嵌入置零来实现。然而,这种策略对于的大值仍然是复杂的,因为计算仍然会发生——它们只是被置零。

在context排列在multidimensional grid的情况下,可以通过使用正则卷积从local contexts等效地计算位置lambda。作者把这个运算称为lambda卷积。n维的lambda卷积可以使用n-d与channel乘法器的深度卷积或卷积来实现,将维中的维视为额外的空间维。

image.png

由于计算现在被限制在局部范围内,lambda卷积的时间和内存复杂度与输入长度成线性关系。lambda卷积很容易与其他功能一起使用,比如dilation和striding,并在专门的硬件加速器上实现了优化。这与local self-attention的实现形成了鲜明的对比,后者需要具体化重叠查询和context块的特性补丁,增加了内存消耗和延迟。

# b: batch, n: input length, m: context length, r: scope size,
# k: query/key depth, v: value depth, h: number of heads, d: output dimension.
def compute position lambdas(embeddings, values, impl=’einsum’):
    if impl == ’einsum’: # embeddings shape: [n, m, k]
        position lambdas = einsum(embeddings, values, ’nmk,bmv−>bnkv’)
    else: # embeddings shape: [r, k]
        if impl == ’conv’:
            embeddings = reshape(embeddings, [r, 1, 1, k])
            values = reshape(values, [b, n, v, 1])
            position lambdas = conv2d(values, embeddings)
        elif impl == ’depthwise conv’:
            # Reshape and tile embeddings to [r, v, k] shape
            embeddings = reshape(embeddings, [r, 1, k])
            embeddings = tile(embeddings, [1, v, 1])
            position lambdas = depthwise conv1d(values, embeddings)
        # Transpose from shape [b, n, v, k] to shape [b, n, k, v]
        position lambdas = transpose(position lambdas, [0, 1, 3, 2])
    return position lambdas
def lambda layer(queries, keys, embeddings, values, impl=’einsum’):
    """Multi−query lambda layer."""
    content lambda = einsum(softmax(keys), values, ’bmk,bmv−>bkv’)
    position lambdas = compute position lambdas(embeddings, values, impl=impl)
    content output = einsum(queries, content lambda, ’bhnk,bkv−>bnhv’)
    position output = einsum(queries, position lambdas, ’bhnk,bnkv−>bnhv’)
    output = reshape(content output + position output, [b, n, d])
    return output


5 问题讨论


1、lambda层与attention操作相比如何?

Lambda层规模有利比较self-attention。使用self-attention的Vanilla Transformers有θ内存footprint,而LambdaNetworks有θ内存footprint。这使得lambda层能够以更高的分辨率和更大的批处理规模使用。

此外,lambda卷积的实现比它的local self-attention对等物更简单、更快。最后,ImageNet实验表明lambda层优于self-attention,证明了lambda层的好处不仅仅是提高了速度和可伸缩性。

2、lambda层与线性注意力机制有何不同?

Lambda层推广和扩展了线性注意力公式,以捕获基于位置的交互,这对于建模高度结构化的输入(如图像)至关重要的;

image.png

由于目标不是近似一个attention kernel,lambda层可以通过使用非线性和规范化进一步提升性能,

image.png

3、在视觉领域如何最好地使用lambda层?

与global或local attention相比,lambda层改进了可伸缩性、速度和实现的简易性,这使得它们成为可视化领域中使用的一个强有力的候选对象。消融实验表明,当优化速度-精度权衡时,lambda层在视觉架构的中分辨率和低分辨率阶段最有利。也可以设计完全依赖lambda层的架构,这样可以更有效地进行参数化处理。作者在附录A中讨给出了使用的意见。

4、lambda层的泛化性如何?

虽然这项工作主要集中在静态图像任务上,但作者注意到lambda层可以被实例化来建模各种结构上的交互,如图形、时间序列、空间格等。lambda层将在更多的模式中有帮助,包括多模态任务。作者在附录中对masked contexts和auto-regressive进行了讨论和复现。


6 实验


6.1 ImageNet分类

该实验主要是针对基于ResNet50的架构改进设计和实验,通过下表可以看出在ResNet50的基础上提升还是比较明显的:

通过下表可以看出Lambda层可以捕捉在高分辨率图像上的全局交互,并获得1.0%的提升,且相较于local self-attention速度提升接近3倍。此外,位置嵌入可以跨lambda层共享,以最小的退化成本进一步减少内存需求。最后,lambda卷积具有线性的内存复杂度,这对于检测或分割中看到的非常大的图像是实用的。

image.png

通过下表可以看出,基于EfficientNet的改进在不损失精度的情况下,可以将训练速度提升9倍之多,推理速度提升6倍之多。

6.2 COCO目标检测实验

在下表中,作者在COCO目标检测和实例分割任务评估了LambdaResNets作为Mask-RCNN的Backbone的性能。使用lambda层可以在所有目标大小上产生一致的增益,特别是那些最难定位的小对象。这表明lambda层对于需要local信息的更复杂的视觉任务也具有竞争力。


7 参考


[1].LAMBDANETWORKS: MODELING LONG-RANGE INTERACTIONS WITHOUT ATTENTION

相关文章
|
14天前
|
机器学习/深度学习 人工智能
类人神经网络再进一步!DeepMind最新50页论文提出AligNet框架:用层次化视觉概念对齐人类
【10月更文挑战第18天】这篇论文提出了一种名为AligNet的框架,旨在通过将人类知识注入神经网络来解决其与人类认知的不匹配问题。AligNet通过训练教师模型模仿人类判断,并将人类化的结构和知识转移至预训练的视觉模型中,从而提高模型在多种任务上的泛化能力和稳健性。实验结果表明,人类对齐的模型在相似性任务和出分布情况下表现更佳。
38 3
|
27天前
|
机器学习/深度学习 数据可视化 测试技术
YOLO11实战:新颖的多尺度卷积注意力(MSCA)加在网络不同位置的涨点情况 | 创新点如何在自己数据集上高效涨点,解决不涨点掉点等问题
本文探讨了创新点在自定义数据集上表现不稳定的问题,分析了不同数据集和网络位置对创新效果的影响。通过在YOLO11的不同位置引入MSCAAttention模块,展示了三种不同的改进方案及其效果。实验结果显示,改进方案在mAP50指标上分别提升了至0.788、0.792和0.775。建议多尝试不同配置,找到最适合特定数据集的解决方案。
201 0
|
25天前
|
机器学习/深度学习 Web App开发 人工智能
轻量级网络论文精度笔(一):《Micro-YOLO: Exploring Efficient Methods to Compress CNN based Object Detection Model》
《Micro-YOLO: Exploring Efficient Methods to Compress CNN based Object Detection Model》这篇论文提出了一种基于YOLOv3-Tiny的轻量级目标检测模型Micro-YOLO,通过渐进式通道剪枝和轻量级卷积层,显著减少了参数数量和计算成本,同时保持了较高的检测性能。
31 2
轻量级网络论文精度笔(一):《Micro-YOLO: Exploring Efficient Methods to Compress CNN based Object Detection Model》
|
25天前
|
机器学习/深度学习 编解码 算法
轻量级网络论文精度笔记(三):《Searching for MobileNetV3》
MobileNetV3是谷歌为移动设备优化的神经网络模型,通过神经架构搜索和新设计计算块提升效率和精度。它引入了h-swish激活函数和高效的分割解码器LR-ASPP,实现了移动端分类、检测和分割的最新SOTA成果。大模型在ImageNet分类上比MobileNetV2更准确,延迟降低20%;小模型准确度提升,延迟相当。
51 1
轻量级网络论文精度笔记(三):《Searching for MobileNetV3》
|
25天前
|
编解码 人工智能 文件存储
轻量级网络论文精度笔记(二):《YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object ..》
YOLOv7是一种新的实时目标检测器,通过引入可训练的免费技术包和优化的网络架构,显著提高了检测精度,同时减少了参数和计算量。该研究还提出了新的模型重参数化和标签分配策略,有效提升了模型性能。实验结果显示,YOLOv7在速度和准确性上超越了其他目标检测器。
43 0
轻量级网络论文精度笔记(二):《YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object ..》
|
3月前
|
机器学习/深度学习 算法 网络架构
神经网络架构殊途同归?ICML 2024论文:模型不同,但学习内容相同
【8月更文挑战第3天】《神经语言模型的缩放定律》由OpenAI研究人员完成并在ICML 2024发表。研究揭示了模型性能与大小、数据集及计算资源间的幂律关系,表明增大任一资源均可预测地提升性能。此外,论文指出模型宽度与深度对性能影响较小,较大模型在更多数据上训练能更好泛化,且能高效利用计算资源。研究提供了训练策略建议,对于神经语言模型优化意义重大,但也存在局限性,需进一步探索。论文链接:[https://arxiv.org/abs/2001.08361]。
43 1
|
4月前
|
机器学习/深度学习 计算机视觉
【YOLOv8改进 - 注意力机制】Gather-Excite : 提高网络捕获长距离特征交互的能力
【YOLOv8改进 - 注意力机制】Gather-Excite : 提高网络捕获长距离特征交互的能力
|
4月前
|
机器学习/深度学习 编解码 计算机视觉
【YOLOv8改进- Backbone主干】BoTNet:基于Transformer,结合自注意力机制和卷积神经网络的骨干网络
【YOLOv8改进- Backbone主干】BoTNet:基于Transformer,结合自注意力机制和卷积神经网络的骨干网络
|
3月前
|
人工智能 算法 安全
【2023 年第十三届 MathorCup 高校数学建模挑战赛】C 题 电商物流网络包裹应急调运与结构优化问题 赛后总结之31页论文及代码
本文总结了2023年第十三届MathorCup高校数学建模挑战赛C题的解题过程,详细阐述了电商物流网络在面临突发事件时的包裹应急调运与结构优化问题,提出了基于时间序列预测、多目标优化、遗传算法和重要性评价模型的综合解决方案,并提供了相应的31页论文和代码实现。
73 0
|
4月前
|
存储 Java Unix
(八)Java网络编程之IO模型篇-内核Select、Poll、Epoll多路复用函数源码深度历险!
select/poll、epoll这些词汇相信诸位都不陌生,因为在Redis/Nginx/Netty等一些高性能技术栈的底层原理中,大家应该都见过它们的身影,接下来重点讲解这块内容。