元图:通过元学习进行小样本的链接预测

简介: 元图:通过元学习进行小样本的链接预测

image.png

今天给大家介绍McGill University和Uber AI一起在 NeurIPS 2020发表的一篇关于元学习的论文。该论文针对目前链接预测的方法不能够在多图的情况下有效的传递、利用图中的信息,以及不能从稀疏图中有效的学习这两个问题,提出了一个基于梯度下降的元学习框架——Meta-Graph。实验结果显示,对于多图和稀疏图的情况而言,MetaGraph相比于现有的模型有更好的效果。


1


简介


常见的深度学习模型,目的是学习一个用于预测的数学模型。而元学习面向的不是学习的结果,而是学习的过程。其学习的不是一个直接用于预测的数学模型,而是学习“如何更快更好地学习一个数学模型”,简单来说就是学习如何学习。


对于一个有着结点和边的表示的图,链接预测的目的是去学习这个图,然后推断结点之间目前未知的边,以达到预测的目的。例如在社交网络中,我们可以使用链接预测来增强友谊推荐系统,或者在生物网络数据的情况下,利用链接预测来推断药物,蛋白质,疾病之间可能的关系。


目前的主流的链接预测方法的一个特点就是,这些工作通常指关注一个特定的问题设置:它通常假定链接预测将在单个大图上执行,并且该图相对完整。而在这项工作中,作者希望可以通过元学习,从多个图(每个图仅仅包含完整图的小部分数据)上进行链接预测。


2


主要贡献


Meta-Graph是基于梯度下降的元学习方法。作者把图上的分布看作是任务的分布(也就是一幅图看成是一个任务。多个任务组成我们拥有的全部数据),对于每一个任务,使用的模型是可以进行few-shot链接预测的图神经网络VAGE。从不同的任务中可以学习到一组全局初始化参数。


为了使得模型更加迅速的适应新的任务,还引入了图签名函数(Graph signature function),用于将每个图的结构映射成为VAGE的初始输入:。


利用全局初始化参数和图签名函数来初始化VGAE的推理模型。该论文的模型框架如下图右图所示,左图为MAML模型与Meta-Graph模型的对比:

image.png

Meta-Graph 背后关键的思想是,使用基于梯度的元学习来优化VGAE推理模型中的全局初始化参数,同时还学习了调制图形中参数初始化的编码函数。通过不同的任务不断去完善全局初始化参数和图签名函数,最后可以利用这两个组件,在新的任务上实现更加优异的表现。主要的算法如下图所示:

image.png

3


实验


作者设计了三种新颖的基准测试,以实现few-shot链接预测任务。


3.1、模型表现 对于每个图,只选取其部分边以及边所对应的结点(作者设置了取10%,20%,30%三种情况),以构造出稀疏图。下图展示了该模型在不同程度的稀疏图下的平均AUC表现。总体而言,Meta-Graph在除了一种设置外的其他所有设置中均实现了最高的平均AUC。与经典的基于梯度下降的元学习方法MAML相比较,平均AUC提高了4.8%,与Finetune基准相比,则提高了5.3%。值得一提的是,对于每个图,仅使用10%的图边缘进行训练时,Meta-Graph表现出特别强劲的性能,这突出了作者的框架可以从稀疏图中有效的学习。

image.png

3.2、新任务的适应性 下图展示了通过Meta-Graph,在一组稀疏图训练数据中,仅执行5次梯度更新后的平均AUC。可以看到,Meta-Graph的性能与MAML相比提高了9.4%,和Finetune相比平均提高了8.0%。这突出显示了,图不仅仅可以从稀疏的边缘样本中学习,而且还可以仅使用少量的梯度步骤就可以快速学习新的数据。

image.png

4


讨论


作者设计了Meta-Graph框架来解决few-shot链接预测的问题。该框架采用基于梯度的元学习来优化局部链接预测模型的全局初始化参数,同时还学习每个图的编码函数(图签名函数)。根据实验结果可以知道,在三个截然不用的基线任务上,Meta-Graph取得了比较好的表现。就该方法的局限性而言,作者认为一个关键的局限性是图形签名函数仅限于通过当前图形的编码来调制本地链接预测模型,而该图形并未明确捕获数据集中图形之间的成对相似性。可以通过学习图之间的相似性度量或者内核来拓展原图,然后将其用于条件元学习。


目录
相关文章
|
5月前
|
机器学习/深度学习 算法
【阿旭机器学习实战】【30】二手车价格预估--KNN回归案例
【阿旭机器学习实战】【30】二手车价格预估--KNN回归案例
|
6月前
|
机器学习/深度学习 存储 数据采集
随机森林填充缺失值、BP神经网络在亚马逊评论、学生成绩分析研究2案例合集1
随机森林填充缺失值、BP神经网络在亚马逊评论、学生成绩分析研究2案例合集
|
6月前
|
机器学习/深度学习 算法 Python
R语言VaR市场风险计算方法与回测、用LOGIT逻辑回归、PROBIT模型信用风险与分类模型
R语言VaR市场风险计算方法与回测、用LOGIT逻辑回归、PROBIT模型信用风险与分类模型
|
6月前
R语言参数检验 :需要多少样本?如何选择样本数量
R语言参数检验 :需要多少样本?如何选择样本数量
|
6月前
|
机器学习/深度学习 算法
R语言非参数方法:使用核回归平滑估计和K-NN(K近邻算法)分类预测心脏病数据
R语言非参数方法:使用核回归平滑估计和K-NN(K近邻算法)分类预测心脏病数据
|
6月前
|
机器学习/深度学习 算法 数据挖掘
survey和surveyCV:如何用R语言进行复杂抽样设计、权重计算和10折交叉验证?
survey和surveyCV:如何用R语言进行复杂抽样设计、权重计算和10折交叉验证?
304 1
|
6月前
|
机器学习/深度学习 测试技术
用R语言实现神经网络预测股票实例
用R语言实现神经网络预测股票实例
|
机器学习/深度学习 算法
分类预测 | MATLAB实现MIV-SVM的平均影响值MIV算法结合支持向量机分类预测
分类预测 | MATLAB实现MIV-SVM的平均影响值MIV算法结合支持向量机分类预测
|
机器学习/深度学习 人工智能 算法
让模型训练速度提升2到4倍,「彩票假设」作者的这个全新PyTorch库火了
让模型训练速度提升2到4倍,「彩票假设」作者的这个全新PyTorch库火了
154 0
让模型训练速度提升2到4倍,「彩票假设」作者的这个全新PyTorch库火了
|
机器学习/深度学习 PyTorch 算法框架/工具
DeepTime:时间序列预测中的元学习模型
DeepTime,是一个结合使用元学习的深度时间指数模型。通过使用元学习公式来预测未来,以应对时间序列中的常见问题(协变量偏移和条件分布偏移——非平稳)。该模型是时间序列预测的元学习公式协同作用的一个很好的例子。
438 0