NeurIPS 2022 | 直面图的复杂性,港中文等提出面向图数据分布外泛化的因果表示学习(1)

简介: NeurIPS 2022 | 直面图的复杂性,港中文等提出面向图数据分布外泛化的因果表示学习

NeurIPS 2022 | 直面图的复杂性,港中文等提出面向图数据分布外泛化的因果表示学习

机器之心 2022-12-26 12:52 发表于北京

机器之心专栏

作者:Yongqiang Chen


随着深度学习模型的应用和推广,人们逐渐发现模型常常会利用数据中存在的虚假关联(Spurious Correlation)来获得较高的训练表现。但由于这类关联在测试数据上往往并不成立,因此这类模型的测试表现往往不尽如人意 [1]。其本质是由于传统的机器学习目标(Empirical Risk Minimization,ERM)假设了训练测试集的独立同分布特性,而在现实中该独立同分布假设成立的场景往往有限。在很多现实场景中,训练数据的分布与测试数据分布通常表现出不一致性,即分布偏移(Distribution Shifts),旨在提升模型在该类场景下性能的问题通常被称为分布外泛化(Out-of-Distribution)问题。关注学习数据中的相关性而非因果性的 ERM 等一类方法往往难以应对分布偏移。尽管近年涌现了诸多方法借助因果推断(Causal Inference)中的不变性原理(Invariance Principle)在分布外泛化(Out-of-Distribution)问题上取得了一定的进展,但在图数据上的研究依然有限。这是因为图数据的分布外泛化比传统的欧式数据更加困难,给图机器学习带来了更多的挑战。本文以图分类任务为例,对借助因果不变性原理的图分布外泛化进行了探究。



近年来,借助因果不变性原理,人们在欧式数据的分布外泛化问题中取得了一定的成功,但对图数据的研究仍然有限。与欧式数据不同,图的复杂性对因果不变性原理的使用以及克服分布外泛化难题提出了独特的挑战。


为了应对该挑战,我们在本工作中将因果不变性融入到图机器学习中,并提出了因果启发的不变图学习框架,为解决图数据的分布外泛化问题提供了新的理论和方法。


论文已在 NeurIPS 2022 发表,本工作由香港中文大学、香港浸会大学, 腾讯 AI Lab 以及悉尼大学合作完成。




图数据的分布外泛化


图数据的分布外泛化难在哪?


图神经网络近年来在涉及图结构的机器学习应用,如推荐系统、AI 辅助制药等领域,取得了很大的成功。然而,因现有的大部分的图机器学习算法都依赖于数据的独立同分布假设,使得当测试数据和训练数据出现偏移(Distribution Shifts)时,算法的性能会极大下降。同时,因为图数据结构的复杂性,导致图数据的分布外泛化相比于欧式数据更普遍且更具挑战性。


图 1. 图上的分布偏移示例。


首先,图数据的分布偏移可以出现在图的节点特征分布中(Attribute-level Shifts)。例如,在推荐系统中,训练数据涉及的商品可能采自一些比较流行的类别,涉及到的用户也可能来自于某些特定地区,而在测试阶段,系统则需要妥善处理所有类别以及地区的用户和商品 [2,3,4]。此外,图数据的分布偏移还可以出现在图的结构分布中(Structure-level Shifts)。早在 2019 年,人们就注意到,在较小的图上进行训练得到的图神经网络难以学到有效的注意力(Attention)权重以泛化到更大的图上 [5],这也推动了一系列相关工作的提出 [6,7]。在现实场景中,这两类分布偏移往往可能同时出现,并且这些不同层级的分布偏移还可以和所要预测的标签具有不同的虚假关联模式。如在推荐系统中,来自特定类别的商品与特定地区的用户往往会在商品用户交互图上展现独特的拓扑结构 [4]。在药物分子属性预测中,训练时涉及的药物分子可能偏小,同时预测的结果也会受到实验测定环境的影响 [8]。


此外,欧式空间的分布外泛化往往会假设数据来自于多个环境(Environment)或者域(Domain),并进一步假设训练期间模型能够获取训练数据中每个样本所属的环境,以此来发掘跨越环境的不变性。然而,要获得数据的环境标签往往需要和数据相关的一些专家知识,而由于图数据的抽象性,使得图数据的环境标签获得更加昂贵。因此,大部分现有的图数据集如 OGB 都不含此类环境标签信息,即便少部分如 DrugOOD 数据集存在环境标签,但也存在不同程度的噪声。


现有方法能否解决图上的分布外泛化问题?


为了对图数据分布外泛化的挑战有一个直观的理解,我们基于 Spurious-Motif [9] 数据集构建新的数据以进一步实例化上述几大挑战,并尝试使用现有的方法如欧式数据上分布外泛化的训练目标 IRM [10],或者具有更强表达能力的 GNN [11],分析能否通过已有的方法解决图数据的分布外泛化问题。


图 2. Spurious Motif 数据集示例。


Spurious Motif 任务如图 2 所示,主要根据输入的图中是否含有特定结构的子图(如 House,或者 Cycle)对图标签进行判断,其中节点颜色代表节点的属性。使用该数据集可以比较清晰地测试不同层级的分布偏移对图神经网络性能的影响。对于一个使用 ERM 进行训练的普通 GNN 模型:


  • 如果训练阶段大部分有 House 子图的样本都节点大部分绿色,而 Cycle 则是蓝色,那么在测试阶段,模型则倾向于预测任何含大量绿色节点的图为 “House”,而蓝色节点的图为 “Cycle”。
  • 如果训练阶段大部分有 House 子图的样本都与一个六边形子图共同出现,那么在测试阶段,模型则倾向于判定任何含有六边形结构的图为 “House”。


此外,模型在训练时无法获得任何和环境标签相关的信息,得到实验结果如图 3 所示(更多结果可以查阅论文附录 D)。


图 3. 现有方法在不同图分布偏移下的表现。


如图 3 所示,普通的 GCN 不论是在使用 ERM 或者 IRM 训练,都无法应对图的结构偏移(Struc);而在增加了图节点属性偏移(Mixed)以及图大小分布偏移后(图 3 中),模型性能将进一步降低;此外即便使用具有更强表达能力的 kGNN 也难以避免严重的性能损失(平均性能的降低,或更大的方差)。


由此,我们自然地引出所要研究的问题:如何才能获得一个具有应对多种图分布偏移的 GNN 模型?


面向图数据分布外泛化的因果模型


为了解决上述问题,我们需要对学习目标,即不变图神经网络(Invariant GNN),进行定义,即在最糟糕的环境下仍旧表现良好的模型(严谨的定义参见论文):


定义 1(不变图神经网络)给定一系列收集自不同的具有因果关联的环境的图分类数据集,其中包含被认为是来自环境 e 的独立同分布样本,考虑一个图神经网络,其中分别是作为输入的图空间和样本空间,f 是不变图神经网络,当且仅当,即最小化所有环境的最坏经损失 (worst empirical risk),其中为模型在环境中的经验损失。


模型在训练时只能获得部分的训练环境中的数据,如果不对数据的过程进行任何假设,不变图神经网络定义所要求的 minmax 最优性是很难做到的。因此,我们从因果推断(Causal Inference)的角度使用因果模型(Structural Causal Model)对图的生成过程进行建模,并对环境之间的关联进行刻画,以尝试定义图数据上的因果不变性。


图 4. 图数据生成过程的因果模型。


不失一般性,我们将所有影响图生成的隐变量纳入隐空间,并将图的生成过程建模为。此外,对于隐变量,根据其是否受环境 E 影响,我们将其划分成不变隐变量(invariant latent variable)以及虚假隐变量(spurious latent variable)。对应地,隐变量 C 与 S 分别会影响 G 的某个子图的生成,分别记作不变子图以及虚假子图,如图 4 (a) 所示,而 C 主要控制了图的标签 Y。这也可以进一步推出,即 C 与 Y 相比于 S 有更高的互信息。这样的生成过程与许多实际例子相对应,如一个分子的药化属性通常由某个关键的基团(分子子图)决定(如羟基 - HO 之于分子的水溶性)。


此外,C 与Y,S以及 E 在隐空间有多种类型的交互,主要跟进虚假隐变量 S 与标签 Y 是否在有不变隐变量 C 之外额外的关联,即,可以概括为两种:如图 4 (b) 的 FIIF(Fully Informative Invariant Feature)以及图 4 (c) 的 PIIF(Partially Informative Invariant Feature)。其中 FIIF 表示给定不变信息后标签与虚假相关量独立。PIIF 则相反。需要说明的是,为了尽可能地覆盖更多的图分布偏移,我们的因果模型致力于对各种图生成模型的广泛的建模。如有更多关于图生成过程的知识,图 4 所示的因果模型则可以进一步泛化到更具体的例子。如在附录 C.1 中,我们展示了如何通过增加额外图极限(graphon)的假设,将因果图泛化至先前 Bevilacqua 等人用于分析图大小分布偏移的工作 [7]。


基于上述的因果分析,我们可以知道,当模型只使用不变子图 进行预测的时,即只使用之间的关联,模型的预测才不会受到环境 E 的改变而影响;反之,如果模型的预测依赖于任何与 S 或有关的信息,其预测结果将会因为 E 的变化发生极大的改变,从而出现性能损失。因此,我们的目标可以从学习一个不变图神经网络,进一步细化至:a) 识别潜在的不变子图;b) 用识别的子图预测 Y。为了进一步与数据生成的算法过程相对应,我们进一步把图神经网络拆分为子图识别网络(Featurizer GNN)和分类网络(Classifier GNN),且,其中的子图空间。那么模型的学习目标则可表示为如公式 (1) 所示:



其中,,为子图识别网络对不变子图的预测;与 Y 互信息,通常,最大化可以通过最小化使用预测 Y 的经验损失实现。然而,由于 E 的缺失,我们难以直接使用 E 对进行独立性的验证,为此,我们必须寻求其他等价条件以识别需要的不变子图。


因果启发的不变图学习


为了解决在缺失时的不变子图识别问题,基于公式 (1) 的框架,我们希望寻求一个公式 (1) 的易于实现的等价条件。特别地,我们首先考虑一种比较简单的情况,即潜在的不变子图大小固定且已知,在这样的条件下,考虑最大化,尽管有同样的大小,但因为与 Y 也存在关联,所以在没有任何其他约束的情况下,最大化可能会使得估计得到的不变子图中包含部分与 Y 有互信息的虚假子图。


为了将可能的虚假子图部分 “挤” 出去,我们将进一步从因果模型中寻求更多关于特有的属性意到,不论是 PIIF 还是 FIIF 的虚假关联类型,对于最大化与标签 Y 互信息的子图,我们有:


  • 不同环境中与相同不变隐变量 C 的不变子图是这两个环境中互信息最大的两个子图,即
  • 同一个环境中对应不同不变隐变量 C 的不变子图两个不变子图是这个环境中互信息最小的两个子图,即

结合上述两个性质,我们可以推出


由于在实践中我们难以直接观察得到,我们则可以通过作为在公式 (2) 中的代理使用。


同时,当同时达到最大化时,将自动最小化,否则模型的预测将坍缩至平凡解。由此,我们得到了在一种简单情况下的不变子图等价条件,结合公式 (1),我们得到了第一版因果启发的不变图学习(Causality-inspired Invariant Graph leArning)框架,即 CIGAv1:

其中,,即与 G 来自同个类别 Y。我们在论文中进一步证明了 CIGAv1 在已知图大小情况下能成功识别图 4 对应的因果模型中潜在的不变子图。然而,由于先前的假设过于理想化,在实践中,不变子图的大小可能会发生改变同时对应的大小我们也往往无法得知。在没有子图大小的假设下,只需要将全图识别为不变子图即能满足 CIGAv1 的要求。因此,我们考虑进一步寻求关于不变子图别的性质用于去除这一假设。


注意到,在最大化时,可能出现中的虚假子图部分与被去除的不变子图部分享有同样的和相关的互信息。那么,我们能否反其道而行之,同时最大化以去除中可能的虚假子图部分呢?答案是肯定的,我们可以利用与 Y 的关联令其与的估计互相竞争。需要注意的是,在最大化时需要保证不会超过,否则将预测的又将陷入平凡解。结合这一额外的条件,我们则可以将关于不变子图大小的假设从公式 (3) 去除,得到如下 CIGAv2:

图 5. 因果启发的不变图学习框架示意。



相关文章
|
9天前
|
自然语言处理 测试技术 计算机视觉
ICLR 2024:谁说大象不能起舞! 重编程大语言模型实现跨模态交互的时序预测
【4月更文挑战第22天】**TIME-LLM** 论文提出将大型语言模型重编程用于时序预测,克服数据稀疏性问题。通过文本原型重编码和Prompt-as-Prefix策略,使LLMs能处理连续时序数据。在多基准测试中超越专业模型,尤其在少量样本场景下效果突出。但面临跨领域泛化、模型调整复杂性和计算资源需求的挑战。[论文链接](https://openreview.net/pdf?id=Unb5CVPtae)
24 2
|
13天前
多水平模型、分层线性模型HLM、混合效应模型研究教师的受欢迎程度
多水平模型、分层线性模型HLM、混合效应模型研究教师的受欢迎程度
12 1
|
13天前
|
数据可视化 数据挖掘
singleCellNet(代码开源)|单细胞层面对细胞分类进行评估,褒贬不一,有胜于无
`singleCellNet`是一款用于单细胞数据分析的R包,主要功能是进行细胞分类评估。它支持多物种和多分组分析,并提供了一个名为`CellNet`的类似工具的示例数据集。用户可以通过安装R包并下载测试数据来运行demo。在demo中,首先加载查询和测试数据,然后训练分类器,接着进行评估,包括查看准确率和召回率的曲线图、分类热图和比例堆积图等。此外,`singleCellNet`还支持跨物种评估,将人类基因映射到小鼠直系同源物进行分析。整体而言,`singleCellNet`是一个用于单细胞分类评估的综合工具,适用于相关领域的研究。
31 6
|
13天前
|
算法 数据挖掘 关系型数据库
有限混合模型聚类FMM、广义线性回归模型GLM混合应用分析威士忌市场和研究专利申请数据
有限混合模型聚类FMM、广义线性回归模型GLM混合应用分析威士忌市场和研究专利申请数据
17 0
|
14天前
R语言如何用潜类别混合效应模型(LCMM)分析抑郁症状
R语言如何用潜类别混合效应模型(LCMM)分析抑郁症状
19 0
|
14天前
|
机器学习/深度学习 数据可视化 算法
R语言贝叶斯广义线性混合(多层次/水平/嵌套)模型GLMM、逻辑回归分析教育留级影响因素数据
R语言贝叶斯广义线性混合(多层次/水平/嵌套)模型GLMM、逻辑回归分析教育留级影响因素数据
46 7
|
3月前
|
计算机视觉
模型落地必备 | 南开大学提出CrossKD蒸馏方法,同时兼顾特征和预测级别的信息
模型落地必备 | 南开大学提出CrossKD蒸馏方法,同时兼顾特征和预测级别的信息
38 0
|
10月前
|
算法
基于模态凝聚算法的特征系统实现算法的自然激励技术(Matlab代码实现)
基于模态凝聚算法的特征系统实现算法的自然激励技术(Matlab代码实现)
|
7月前
|
机器学习/深度学习 算法
如何解决图神经网络过相关?一个IBM的新视角!
如何解决图神经网络过相关?一个IBM的新视角!
|
11月前
|
机器学习/深度学习 算法 知识图谱
浙大团队将化学知识引入机器学习,提出可外推、可解释的分子图模型预测反应性能
浙大团队将化学知识引入机器学习,提出可外推、可解释的分子图模型预测反应性能
142 0