无惧大规模GNN,用子图也一样!中科大提出首个可证明收敛的子图采样方法 | ICLR 2023 Spotlight

简介: 无惧大规模GNN,用子图也一样!中科大提出首个可证明收敛的子图采样方法 | ICLR 2023 Spotlight


 新智元报道  

编辑:好困

【新智元导读】中科大王杰教授团队提出局部消息补偿技术,解决采样子图边缘节点邻居缺失问题,弥补图神经网络(GNNs)子图采样方法缺少收敛性证明的空白,推动 GNNs 的可靠落地。


图神经网络(Graph Neural Networks,简称 GNNs)是处理图结构数据的最有效的机器学习模型之一,也是顶会论文的香饽饽。然而,GNNs 的计算效率一直是个硬伤,在大规模图数据上训练 GNNs 常常会遇上邻居爆炸(neighbor explosion)问题——节点表示和随机梯度的计算复杂度会随着图神经网络层数的增加而指数上升。

很多 GNNs 的学术研究都会倾向于选择小规模图数据集(千量级节点数)进行实验,避开 GNNs 的计算效率问题。但是,这一问题在工业界实际落地的场景中无法避免:在大规模图数据(十亿节点)[3] 上,这些 GNNs 根本无法运行。

一个最简单粗暴的办法是:在每次模型训练或预测的时候,从全量图上切出一个子图,在子图上运行 GNNs。这又会带新的问题:在子图上训练的 GNNs 能和全量图上训练的 GNNs 一样吗?子图边缘节点会不会丢失很多邻居信息?

为此,中科大 MIRA Lab 王杰教授团队提出了一种 GNNs 的子图采样训练方法——本地消息补偿(Local Message Compensation,简称 LMC)。

LMC 具有极低的计算开销;并且,理论证明:LMC 在子图上训练的 GNNs 的性能可媲美在全量图上训练的 GNNs,同时 LMC 能加速 GNNs 收敛。相关成果论文已被 ICLR 2023 接收为 Spotlight。

作者列表:石志皓,梁锡泽,王杰

论文链接:https://openreview.net/forum?id=5VBBA91N6n

1. 引言


基于消息传递机制的图神经网络(GNNs)在许多实际应用中取得了巨大成功。然而,在大规模图上训练 GNNs 会遇到众所周知的邻居爆炸(neighbor explosion)问题——节点的依赖性随消息传递层的数量呈指数增长。

子图采样方法——一类备受瞩目的小批量训练(mini-batch training)技术——在反向传播中丢弃小批量之外的消息,以此避免邻居爆炸问题,但同时以牺牲梯度估计的精度为代价。这对它们的收敛性分析收敛速度都提出了重大挑战,严重限制了它们在现实场景中的进一步应用。为了应对这些挑战,我们提出了一种具有收敛性保证的新型子图采样方法——本地消息补偿(Local Message Passing,简称 LMC)。据我们所知,LMC 是首个具有可证明收敛性的子图采样方法。LMC 的关键思想是基于反向传播传递的消息传递建模来恢复在反向传播中被丢弃的消息。通过对正向和反向传播中丢弃的消息进行高效和有效的补偿,LMC 计算出准确的小批量梯度,从而加速收敛。进一步地,我们证明了 LMC 收敛到 GNNs 的一阶驻点(first-order stationary points)。在大规模基准测试任务中的实验表明,LMC 在效率方面明显优于最先进的子图采样方法。2. 背景与问题


2.1 图神经网络

在实际问题中,图结构数据随处可见,例如知识图谱、分子、计算机网络、社交网络、神经元网络、文章引用网络等,如图1所示。

图1. 图结构数据在实际问题中随处可见,图中展示了各式各样的图数据。

图神经网络(Graph Neural Networks,简称 GNNs)通过消息传递范式 [1] 处理图数据,是当前处理图结构数据最有效的机器学习模型之一。在每个消息传递层中,GNNs 迭代地聚合邻居节点的消息,以更新当前节点的表示。这种范式在许多实际应用中取得了巨大的成功,例如搜索引擎 [2]、推荐系统 [3]、材料工程 [4]、分子性质预测 [5, 6],以及组合优化 [7]。具体地,以半监督的结点分类任务为例,GNNs 旨在通过最小化目标函数  来学习结点嵌入  以及参数 ,其中 , 是有标签结点的集合, 是参数为  的输出层与一个损失函数的组合, 是结点  的嵌入, 是结点  的标签, 是结点特征, 是图上所有边的集合。一个  层的 GNN 通过  次有着不同参数  的消息传递迭代来生成最终的结点嵌入 :其中 , 是第  层的消息传递函数,参数为 。消息传递函数  遵循聚合-更新机制,即

其中  是为结点  的每个邻居生成消息的函数, 是将邻居消息集合映射到最终消息  的聚合函数, 是组合从前的结点嵌入 ,消息 ,以及结点特征  的更新函数。2.2 邻居爆炸尽管 GNNs 在许多应用中取得了巨大的成功,这种消息迭代机制也给 GNNs 在大规模图数据上的训练带来了挑战。使用有限的 GPU 内存将深度模型扩展到任意大规模数据的一种常见方法是通过小批量梯度近似全批次梯度。然而,对于图结构数据,由于众所周知的邻居爆炸问题,计算小批量节点的损失函数和相应的小批量梯度的成本是非常昂贵的。具体地,对于 GNNs 而言,一个结点在第  层消息传递中的嵌入递归地依赖于它邻居在第  层的嵌入。因此,计算复杂度会随着消息传递层数的增加而指数级上涨,带来无法令人接受的计算开销。2.3 子图采样方法为了解决邻居爆炸问题,最近的一些工作提出了各种各样的采样技术以减少消息传递所牵涉的节点个数。例如,结点采样方法 [8, 9] 和层采样方法 [10, 11, 12] 会在消息传递中递归地采样邻居,从而估计结点嵌入以及对应的小批量梯度。与这种递归的范式不同,子图采样方法 [13, 14, 15, 16] 使用了一种更为简单、成本低廉的一次性采样范式(one-shot sampling fashion)——为不同的消息传递层采样同一个子图,该子图由同一小批量结点所构建。通过丢弃小批量之外的消息,子图采样方法将消息传递过程限制在小批量中,使得复杂度随消息传递层数的增加而线性增长,极大降低了计算开销。此外,通过直接在子图上运行 GNNs,子图采样方法适用于非常广泛的 GNN 结构。由于上述优势,子图采样方法近期收到了越来越多的关注。然而,子图采样方法这种丢弃小批量外部消息的做法牺牲了梯度估计的精度,这给它们的收敛性分析和收敛速度带来了极大挑战:

  • 首先,近期工作 [9, 17] 表明,不准确的小批量梯度会严重降低 GNNs 的收敛速度。
  • 其次,我们的实验表明,现有子图采样方法在批量大小较小时难以达到全梯度下训练的表现;而我们在实际应用中经常会将批量大小设置为一个较小的数字,以避免超出 GPU 的显存。

对此,我们发问:能否设计一个子图采样方法,它既有极低的计算开销,又有媲美全梯度训练的预测精度,同时还有严格的收敛性保证

我们的回答是:LMC 能做到!

3. 方法:局部信息补偿 LMC

我们的研究思路受到了 VR-GCN [9] 的启发,其主要抓手是把节点或层级别的递归采样看成一个无偏的基线方法——Standard SGD的近似,进而通过对于梯度的误差分析来证明收敛性。

然而,很难把子图采样方法看成 Standard SGD 的近似,因为子图采样在每一层都采样相同的子图,每一层的计算都会引入不可避免的偏差。因此,我们第一步是先提出一个 Backward SGD,它更便于我们的分析子图采样这种一次性采样范式。在 Backward SGD 的基础上,我们分析如何给子图采样方法加入合适的补偿项,减少它的偏差,进而找到一个可证明收敛的子图采样算法。3.1 将反向传播建模为消息传递梯度  是容易计算的,所以我们主要介绍如何计算 。令 , 为辅助变量,则有 。由链式法则,我们能够基于  迭代地计算 :

以及

然后,我们可以使用处理向量-Jacobian 积的自动求导工具来计算梯度 我们将反向传播(即迭代计算方程 (3) 的过程)建模为消息传递。为了看到这一点,我们只需注意到 (3) 等价于

其中  是  的第  列。方程 (5) 分别使用 、求和聚合,以及恒等映射作为生成函数、聚合函数,以及更新函数。

3.2 LMC 的基石:Backward SGD

基于这一反向传播的消息传递建模,我们设计了一个 SGD 变体——Backward SGD,它能带来无偏的梯度估计。需要说明的是,Backward SGD 是我们主要方法 LMC 的基石给定一个被采样的小批量 ,假设我们已经获得了小批量中结点的准确嵌入  和准确辅助变量 。读者需要注意:这一假设是比较强的,事实上我们很难计算准确的嵌入和辅助变量。下一节介绍的主要方法 LMC 所做的就是利用子图采样来估计结点嵌入和辅助变量。首先,Backward SGD 计算参数  的小批量梯度 :

然后,Backward SGD 计算参数  的小批量梯度 :

注意到:对于不同的层数 ,小批量梯度  牵涉的小批量是同一个(即 ),这就给基于 Backward SGD 设计子图采样方法提供了基础。Backward SGD 的另一个吸引人的性质是:小批量梯度  和  是无偏的,如第4节中的定理1所示。详细的证明请参见原论文附录。3.3 本地消息补偿 LMC在上一节中,Backward SGD 所计算的小批量梯度依赖于小批量中节点的准确嵌入和准确辅助变量,而不是整张图。然而,Backward SGD 仍然不是可扩展的(scalable),因为邻居爆炸问题会使得准确结点嵌入和辅助变量的计算极其昂贵,所以事实上我们无法得到准确的  和 。在这一节中,为了解决邻居爆炸问题,我们提出了新颖的子图采样方法——本地消息补偿(Local Message Compensation,简称 LMC)。LMC 首先通过不完全最新值(incomplete up-to-date values)和历史值(historical values)的凸组合来高效地估计  和 ,然后利用方程 (6) 和 (7) 计算小批量梯度。在之后的理论分析中,我们证明了 LMC 收敛到 GNNs 的一阶驻点。在算法1和理论分析中,我们用  表示一个第  层、第  次迭代时的量,而在其他地方我们省略上标 ,用  来表示。在每个训练迭代中,我们采样一个小批量结点 ,通过历史值  和 ,以及不完全最新值 和  的凸组合来高效地估计  和 。

为便于读者理解方法的核心思想,我们将 LMC 与现有最先进方法 GAS [15] 的前向传播、反向传播计算图展示在图2。

图2. LMC 与 GAS 前向传播与反向传播的计算图。可以看到,在前向传播和反向传播中,LMC 均进行了小批量结点与一跳邻居之间的消息交互(即补偿),而 GAS 在反向传播中丢弃了小批量之外的消息。

  • 在前向传播中,我们将  的 临时嵌入设为 ,然后以  的顺序更新  中的历史嵌入 。特别地,在第 层,我们进行以下计算:

  • 在反向传播中,我们将  的临时辅助变量设为 ,然后以  的顺序更新 中的历史辅助变量 。特别地,在第 层,我们进行以下计算:

关于方法的具体细节、详细解释、计算复杂度分析等,请读者参见原论文。

我们分别称  和 为第  层前向传播和反向传播的本地消息补偿

4.理论分析


理论分析包含三个主要定理。从直观(说人话)的角度,它们分别在说:

  • 定理1:Backward SGD 的梯度是无偏的。这样,我们就基本可以保证 Backward SGD 的收敛性。
  • 定理2:LMC 所估计的梯度和 Backward SGD 的梯度相差不大,能够被我们给出的上界所控制。
  • 定理3:LMC 收敛到 GNNs 的一阶驻点,这也是我们的最终定理。

理论部分的核心思想是:LMC 和 Backward SGD 的收敛行为一致。

在本节中,我们做如下假设:

  1. 在第  个迭代中,小批量节点  是从  中均匀采样的,对应的有标签节点集  是从  中采样的。
  2. 函数 , , , , ,  是 -Lipschitz 连续的,其中 。
  3. 范数 , , , , , , , , , , ,  被常数  所控制。

定理1. 假设一个小批量  是从结点集合  中均匀采样的,并且对应的有标签结点集合  也是从  中均匀采样的,则方程 (6) 和 (7)  所计算的小批量梯度  和  是无偏的。定理2. 在上述假设下,令  和 ,存在  和  使得定理3. 在上述假设下,再假定最优值  被  控制。令 ,,以及 ,LMC 可保证在  次迭代后找到一个 -驻点使得 ,其中  是随机从  中选取的,。

5. 实验

在实验部分,我们做了4个大数据集的实验,图3列出了其中3个。LMC 的训练集 loss 的收敛速度超过了所有其他方法,但是测试集上有个众所周知的泛化问题,尽管 LMC 在训练集上收敛快,但模型很快就过拟合了,所以 LMC 在测试集的准确率曲线提升看起来并不如训练集明显。我们发现,LMC 最终的预测准确率和 SOTA 方法 GAS 相差不大,这是因为 early stopping 技术,即在训练过程中,测试集的曲线是震荡的,early stopping 汇报的大概率是测试集准确率曲线的最大值。为了突出训练过程中波动性的影响,我们在图3汇报的曲线用滑动窗口取了平均值构成实线,标准差构成阴影部分。可以看出,LMC 在训练稳定性上明显超过 GAS。GAS 和 LMC 最终预测准确率的差距会在 batch size 比较小的情况下有所体现(图6),这时 METIS 的作用会被削弱。

图3. 收敛时间对比

我们进一步统计达到一个给定测试集准确率的时间。我们先跑一个全梯度 GD,得到它的最高测试集准确率,然后再分别运行几种子图采样方法,计算达到这一准确率需要的时间。在 REDDIT 数据集上,LMC 相对于 GAS 的加速比达到2倍

图4. 达到最高测试集准确率时间对比

更进一步,LMC 的梯度估计和 Backward SGD 差不多,所以梯度估计是更准确的。我们这里统计了计算过程中的相对误差,如图5所示,确实是 LMC 的估计误差最小。

图5. 相对误差对比

进一步做了 small batch size 下的实验,前面在子图采样算法中举了一个例子,子图规模很小的话,丢弃的节点就很多,很容易达到次优。如表三所示,我们的方法对 batch size 更加鲁棒,因此在计算资源受限的情景下,LMC的优势会更加明显

图6. 不同批量大小的表现

最后是消融实验,相对于 SOTA 的 GAS 方法,我们对前向传播过程的补偿消息进行了改进,并且在反向传播也加入了一个补偿。如图7所示,我们发现,在batch size很小的情况下,反向传播的补偿很重要,因为这一 设定下,丢弃了很多消息,导致收敛到次优解。在batch size较大的时候,采样子图一阶邻居是很大的,我们通过采样子图一阶邻居内部的消息传递,提高了历史信息的准确率,也能提高子图采样算法的性能。

图7. 消融实验

参考资料:[1] Hamilton, William L. "Graph representation learning." Synthesis Lectures on Artifical Intelligence and Machine Learning 14.3 (2020): 1-159.[2] Brin, Sergey, and Lawrence Page. "The anatomy of a large-scale hypertextual web search engine." Computer networks and ISDN systems 30.1-7 (1998): 107-117.[3] Fan, Wenqi, et al. "Graph neural networks for social recommendation." The world wide web conference. 2019.[4] Gostick, Jeff, et al. "OpenPNM: a pore network modeling package." Computing in Science & Engineering18.4 (2016): 60-74.[5] Moloi, N. P., and M. M. Ali. "An iterative global optimization algorithm for potential energy minimization." Computational Optimization and Applications 30 (2005): 119-132.[6] Kearnes, Steven, et al. "Molecular graph convolutions: moving beyond fingerprints." Journal of computer-aided molecular design 30 (2016): 595-608.[7] Wang, Zhihai, et al. "Learning Cut Selection for Mixed-Integer Linear Programming via Hierarchical Sequence Model." arXiv preprint arXiv:2302.00244 (2023).[8] Hamilton, Will, Zhitao Ying, and Jure Leskovec. "Inductive representation learning on large graphs." Advances in neural information processing systems 30 (2017). [9] Chen, Jianfei, Jun Zhu, and Le Song. "Stochastic training of graph convolutional networks with variance reduction." arXiv preprint arXiv:1710.10568 (2017).[10] Chen, Jie, Tengfei Ma, and Cao Xiao. "Fastgcn: fast learning with graph convolutional networks via importance sampling." arXiv preprint arXiv:1801.10247 (2018).[11] Zou, Difan, et al. "Layer-dependent importance sampling for training deep and large graph convolutional networks." Advances in neural information processing systems 32 (2019).[12] Huang, Wenbing, et al. "Adaptive sampling towards fast graph representation learning." Advances in neural information processing systems 31 (2018).[13] Chiang, Wei-Lin, et al. "Cluster-gcn: An efficient algorithm for training deep and large graph convolutional networks." Proceedings of the 25th ACM SIGKDD international conference on knowledge discovery & data mining. 2019.[14] Zeng, Hanqing, et al. "Graphsaint: Graph sampling based inductive learning method." arXiv preprint arXiv:1907.04931 (2019).[15] Fey, Matthias, et al. "Gnnautoscale: Scalable and expressive graph neural networks via historical embeddings." International Conference on Machine Learning. PMLR, 2021.[16] Zeng, Hanqing, et al. "Decoupling the depth and scope of graph neural networks." Advances in Neural Information Processing Systems 34 (2021): 19665-19679.[17] Cong, Weilin, et al. "Minimal variance sampling with provable guarantees for fast training of graph neural networks." Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. 2020.


相关文章
|
6月前
|
机器学习/深度学习 缓存 算法
【论文速递】CVPR2020 - CRNet:用于小样本分割的交叉参考网络
【论文速递】CVPR2020 - CRNet:用于小样本分割的交叉参考网络
|
机器学习/深度学习 自动驾驶 计算机视觉
目标检测落地必备Trick | 结构化知识蒸馏让RetinaNet再涨4个点
目标检测落地必备Trick | 结构化知识蒸馏让RetinaNet再涨4个点
381 0
|
6月前
|
算法 数据挖掘 关系型数据库
有限混合模型聚类FMM、广义线性回归模型GLM混合应用分析威士忌市场和研究专利申请数据
有限混合模型聚类FMM、广义线性回归模型GLM混合应用分析威士忌市场和研究专利申请数据
|
人工智能 算法 图形学
山大SIGGRAPH 2023 最佳论文得主分享:点云法向估计及保特征重建
山大SIGGRAPH 2023 最佳论文得主分享:点云法向估计及保特征重建
229 0
|
6月前
|
机器学习/深度学习 固态存储 算法
目标检测的福音 | 如果特征融合还用FPN/PAFPN?YOLOX+GFPN融合直接起飞,再涨2个点
目标检测的福音 | 如果特征融合还用FPN/PAFPN?YOLOX+GFPN融合直接起飞,再涨2个点
281 0
|
计算机视觉
大连理工卢湖川团队TMI顶刊新作 | M^2SNet: 新颖多尺度模块 + 智能损失函数 = 通用图像分割SOTA网络
大连理工卢湖川团队TMI顶刊新作 | M^2SNet: 新颖多尺度模块 + 智能损失函数 = 通用图像分割SOTA网络
476 0
|
6月前
|
机器学习/深度学习 计算机视觉
【论文速递】MMM2020 - 电子科技大学提出一种新颖的局部变换模块提升小样本分割泛化性能
【论文速递】MMM2020 - 电子科技大学提出一种新颖的局部变换模块提升小样本分割泛化性能
42 0
|
6月前
|
机器学习/深度学习 算法 计算机视觉
【论文速递】CVPR2021 - 基于自引导和交叉引导的小样本分割算法
【论文速递】CVPR2021 - 基于自引导和交叉引导的小样本分割算法
57 0
|
机器学习/深度学习 人工智能 自然语言处理
中山大学团队使用端到端图生成架构进行分子图编辑的逆合成预测
中山大学团队使用端到端图生成架构进行分子图编辑的逆合成预测
171 0
|
人工智能
IJCAI 2022 | 用一行代码大幅提升零样本学习方法效果,南京理工&牛津提出即插即用分类器模块
IJCAI 2022 | 用一行代码大幅提升零样本学习方法效果,南京理工&牛津提出即插即用分类器模块
150 0