ICLR 2023 | DIFFormer: 扩散过程启发的Transformer(2)

简介: ICLR 2023 | DIFFormer: 扩散过程启发的Transformer

定理
对于任意的由 (2) 式所定义的能量函数,存在步⻓和相应的扩散率估计


使得由 (1) 式定义的扩散⽅程数值迭代保证每⼀步的能量下降,即

基于这⼀理论结果,我们进⽽提出了扩散过程诱导下的 Transformer 结构,即 DIFFormer,它的每⼀层更新公式表示为:


这⾥的表示衡量相似性的函数,在具体设计时具有很⼤的灵活性。下⾯我们提出两种具体设计,分别称相应的模型结构为 DIFFormer-s 和 DIFFormer-a。


  • DIFFormer-s:采⽤简单的 dot-product 来衡量相似性,作为 attention function(这⾥使⽤ L2 normalization 将输⼊向量限制在 [-1,1] 之间从⽽保证得到的注意⼒权重⾮负):



  • DIFFormer-a:在计算相似度时引⼊⾮线性,从⽽提升模型学习复杂结构的表达能⼒:



当我们考虑每层两两节点之间的全局 attention,⼀个潜在的问题是 all-pair attention 带来的平⽅复杂度。庆幸的是,这⾥ DIFFormer-s 的 attention 定义可以保证每⼀层更新个样本表征的计算复杂度在之内,这⾮常有利于提升模型的时空效率(特别是空间效率,当需要扩展到包含⼤量样本的数据集时)。


为什么能实现复杂度呢
我们可以把
代⼊更新单个样本的聚合公式,然后通过矩阵乘法结合律交换矩阵运算的顺序(这⾥假设):


在上式左边的式⼦中,计算⼀次需要复杂度,⽽⼜因为这是对单个样本的更新公式,因此更新个不同的样本需要的复杂度是。但在右边的式⼦中,分⼦和分⺟的两个求和项对于所有样本是共享的,也就是说在实际计算中只需要算⼀次,⽽后对每个样本的更新只需要,因此更新个样本的总复杂度是。不过对于 DIFFormer-a 的 attention 设计,则⽆法保证的计算复杂度,因为⾮线性的引⼊导致了⽆法交换矩阵运输的次序。下图总结了两个模型在具体实现(采⽤矩阵乘法更新⼀层所有样本的表征)中的运算过程。


两种模型 DIFFormer-s 和 DIFFormer-a 每层更新的运算过程(矩阵形式),红⾊标注的矩阵乘法操作是计算瓶颈。DIFFormer-s 的优势在于可以实现对样本数量 N 的线性复杂度,有利于模型扩展到⼤规模数据集


模型扩展
更进⼀步的,我们可以引⼊更多设计来提升模型的适⽤性和灵活度。上述的模型主要考虑了样本间的 all-pair attention。对于输⼊数据本身就含有样本间图结构的情况,我们可以加⼊现有图神经⽹络(GNN)中常⽤的传播矩阵(propagation matrix)来融合已知的图结构信息,从⽽定义每层的样本表征更新如下


⽐如如果采⽤图卷积⽹络(GCN)中的传播矩阵,则这⾥表示输⼊图,表示其对应的(对⻆)度矩阵。


类似其他 Transformer ⼀样,在每层更新中我们可以加⼊ residual link,layer normalization,以及⾮线性激活。下图展示了 DIFFormer 的单层更新过程。

DIFFormer 的全局输⼊包含样本输⼊特征 X 以及可能存在的图结构 A(可以省略),通过堆叠 DIFFormer layer 更新计算样本表征。在每层更新时,需要计算⼀个全局 attention(具体的可以使⽤ DIFFormer-s 和 DIFFormer-a 两种实现),如果考虑输⼊图结构则加⼊ GCN Conv


另⼀个值得探讨的问题,是如何处理⼤规模数据集(尤其是包含⼤量样本的数据集,此时考虑全局 all-pair attention ⾮常耗费资源)。在这种情况下我们默认使⽤线性复杂度的 DIFFormer-s 的架构,并且可以在每个训练 epoch 对数据集进⾏ random mini-batch 划分。由于线性复杂度,我们可以使⽤较⼤的 batch size 也能使得模型在单卡上进⾏训练(详⻅实验部分)。

对于包含⼤量样本的数据集,我们可以对样本进⾏随机 minibatch 划分,每次只输⼊⼀个 batch 的样本。当输⼊包含图结构时,我们可以只提取 batch 内部样本所组成的⼦图输⼊进⽹络。由于 DIFFormer-s 只需要对 batch size 的线性复杂度,在实际中就可以使⽤较⼤的 batch size,保证充⾜的全局信息

实验结果

为了验证 DIFFormer 的有效性和在不同场景下的适⽤性,我们考虑了多个实验场景,包括不同规模图上的节点分类、半监督图⽚ / ⽂本分类和时空预测任务。

图节点分类实验
此时输⼊数据是⼀张图,图中的每个节点是⼀个样本(包含特征和标签),⽬标是利⽤节点特征和图结构来预测节点的标签。我们⾸先考虑⼩规模图 的实验,此时可以将⼀整图输⼊ DIFFormer。相⽐于同类模型例如 GNN,DIFFormer 的优势在于可以不受限于输⼊图,学习未被观测到的连边关系,从⽽更好的捕捉⻓距离依赖和潜在关系。下图展示了与 SOTA ⽅法的对⽐结果。


进⼀步的我们考虑在⼤规模图上的实验。此时由于图的规模过⼤,⽆法将⼀整图直接输⼊模型(否则将造成 GPU 过载),我们使⽤ mini-batch 训练。具体的,在每个 epoch,随机的将所有节点分为相同⼤⼩的 mini-batch。每次只将⼀个 mini-batch 的节点输⼊进⽹络;⽽对于输⼊图,只使⽤包含在这个 mini-batch 内部的节点所组成的⼦图输⼊进⽹络;每次迭代过程中,DIFFormer 也只会在 mini-batch 内部的节点之间学习 all-pair attention。这样做就能⼤⼤减⼩空间消耗。⼜因为 DIFFormer-s 的计算复杂度关于 batch size 是线性的,这就允许我们使⽤很⼤的 batch size 进⾏训练。下图显示了在 ogbn-proteins 和 pokec 两个⼤图数据集上的测试性能,其中对于 proteins/pokec 我们分别使⽤了 10K/100K 的 batch size。此外,下图的表格也展示了 batch size 对模型性能的影响,可以看到,当使⽤较⼤ batch size 时,模型性能是⾮常稳定的。


图⽚ / ⽂本分类实验
第⼆个场景我们考虑⼀般的分类问题,输⼊是⼀些独⽴的样本(如图⽚、⽂本),样本间没有已观测到的依赖关系。此时尽管没有输⼊图结构, DIFFormer 仍然可以学习隐含在数据中的样本依赖关系。对于对⽐⽅法 GCN/GAT,由于依赖于输⼊图,我们这⾥使⽤ K 近邻⼈⼯构造⼀个样本间的图结构。

相关文章
|
机器学习/深度学习 算法 PyTorch
论文阅读笔记 | 目标检测算法——DETR
论文阅读笔记 | 目标检测算法——DETR
902 0
论文阅读笔记 | 目标检测算法——DETR
|
2月前
|
机器学习/深度学习 搜索推荐
CIKM 2024:LLM蒸馏到GNN,性能提升6.2%!Emory提出大模型蒸馏到文本图
【9月更文挑战第17天】在CIKM 2024会议上,Emory大学的研究人员提出了一种创新框架,将大型语言模型(LLM)的知识蒸馏到图神经网络(GNN)中,以克服文本图(TAGs)学习中的数据稀缺问题。该方法通过LLM生成文本推理,并训练解释器模型理解这些推理,再用学生模型模仿此过程。实验显示,在四个数据集上性能平均提升了6.2%,但依赖于LLM的质量和高性能。论文链接:https://arxiv.org/pdf/2402.12022
77 7
|
3月前
|
机器学习/深度学习 算法 网络架构
神经网络架构殊途同归?ICML 2024论文:模型不同,但学习内容相同
【8月更文挑战第3天】《神经语言模型的缩放定律》由OpenAI研究人员完成并在ICML 2024发表。研究揭示了模型性能与大小、数据集及计算资源间的幂律关系,表明增大任一资源均可预测地提升性能。此外,论文指出模型宽度与深度对性能影响较小,较大模型在更多数据上训练能更好泛化,且能高效利用计算资源。研究提供了训练策略建议,对于神经语言模型优化意义重大,但也存在局限性,需进一步探索。论文链接:[https://arxiv.org/abs/2001.08361]。
47 1
|
5月前
|
机器学习/深度学习 自然语言处理
解决Transformer根本缺陷,CoPE论文爆火:所有大模型都能获得巨大改进
【6月更文挑战第9天】CoPE论文提出了一种新方法,解决Transformer模型位置处理缺陷,通过上下文依赖的位置编码增强序列元素识别,改进选择性复制、计数等任务,提升语言建模和编码任务的困惑度。但CoPE增加模型复杂性,可能受模型大小和数据量限制,且过度依赖上下文可能引入偏见。[https://arxiv.org/pdf/2405.18719]
63 6
|
6月前
|
机器学习/深度学习 算法
ICLR 2024 Oral:用巧妙的传送技巧,让神经网络的训练更加高效
【5月更文挑战第21天】ICLR 2024 Oral 提出了一种名为“传送”的新方法,利用参数对称性提升神经网络训练效率。该方法通过参数变换加速收敛,改善泛化能力,减少了训练所需的计算资源和时间。研究显示,传送能将模型移到不同曲率的极小值点,可能有助于泛化。论文还探讨了将传送应用于元学习等优化算法的潜力,但对传送加速优化的确切机制理解尚不深入,且实际应用效果有待更多验证。[论文链接](https://openreview.net/forum?id=L0r0GphlIL)
65 2
|
6月前
|
机器学习/深度学习 编解码 算法
助力目标检测涨点 | 可以这样把Vision Transformer知识蒸馏到CNN模型之中
助力目标检测涨点 | 可以这样把Vision Transformer知识蒸馏到CNN模型之中
238 0
|
机器学习/深度学习 计算机视觉
深度学习原理篇 第七章:Deformable DETR
简要介绍Deformable DETR的原理和代码实现。
1430 1
|
机器学习/深度学习 编解码 自然语言处理
深度学习进阶篇[9]:对抗生成网络GANs综述、代表变体模型、训练策略、GAN在计算机视觉应用和常见数据集介绍,以及前沿问题解决
深度学习进阶篇[9]:对抗生成网络GANs综述、代表变体模型、训练策略、GAN在计算机视觉应用和常见数据集介绍,以及前沿问题解决
深度学习进阶篇[9]:对抗生成网络GANs综述、代表变体模型、训练策略、GAN在计算机视觉应用和常见数据集介绍,以及前沿问题解决
|
机器学习/深度学习 人工智能 自然语言处理
【ICLR2020】基于模型的强化学习算法玩Atari【附代码】
【ICLR2020】基于模型的强化学习算法玩Atari【附代码】
188 0
|
机器学习/深度学习 数据挖掘
ICLR 2023 | DIFFormer: 扩散过程启发的Transformer(3)
ICLR 2023 | DIFFormer: 扩散过程启发的Transformer
188 0