cs224w(图机器学习)2021冬季课程学习笔记10 Applications of Graph Neural Networks

简介: cs224w(图机器学习)2021冬季课程学习笔记10 Applications of Graph Neural Networks

本章主要内容:

本章继续上一章1内容,讲design space剩下的两部分:图增强,如何训练一个GNN模型(GNN训练全流程)。


在图增强方面:

首先介绍图增强的原因和分类。

然后分别介绍:

graph feature augmentation的方法:使用常数特征、独热编码、图结构信息

graph structure augmentation的方法:

 对稀疏图:增加虚拟边或虚拟节点

 对稠密图:节点邻居抽样


接下来讲GNN模型训练的学习目标。

首先介绍不同粒度任务下的prediction head(将节点嵌入转换为最终预测向量):节点级别的任务可以直接进行线性转换。链接级别的任务可以将节点对的嵌入进行concatenation或点积后进行线性转换。图级别的任务是将图中所有节点嵌入作池化操作,可以通过hierarchical global pooling方法来进行优化(实际应用:DiffPool)。

接下来介绍了预测值和标签的问题:有监督/无监督学习情况下的标签来源。

然后介绍损失函数:分类常用交叉熵2,回归任务常用MSE(L2 loss)。

接下来介绍评估指标:回归任务常用RMSE和MAE,分类任务常用accuracy和ROC AUC。

最后讲了设置GNN预测任务(将图数据拆分为训练/验证/测试集)的方法,分为transductive和inductive两种。


1. Graph Augmentation for GNNs


这一部分在 Lecture 71 的slides中写过,但是在 Lecture 8(本章)课程中讲的,所以我笔记也放在这一部分来做。


  1. 回顾一遍在 Lecture 71 第一节中讲过的GNN图增强部分:

image.png


  1. 为什么要进行图增强?

我们在之前的学习过程中都假设原始数据和应用于GNN的计算图一致,但很多情况下原始数据可能不适宜于GNN:

  • 特征层面:输入图可能缺少特征(也可能是特征很难编码)→特征增强
  • 结构层面:
  1. 图可能过度稀疏→导致message passing效率低(边不够嘛)
  2. 图可能过度稠密→导致message passing代价太高(每次做message passing都需要对好几个节点做运算)
  3. 图可能太大→GPU装不下
  • 事实上输入图很难恰好是适宜于GNN(图数据嵌入)的最优计算图

image.png


  1. 图增强方法
  • 图特征:输入图缺少特征→特征增强
  • 图结构:

1)图过于稀疏→增加虚拟节点/边

2)图过于稠密→在message passing时抽样邻居

3)图太大→在计算嵌入时抽样子图(在后续课程中会专门介绍如何将GNN方法泛化到大型数据上scale up)

image.png


1.1 图特征增强Feature Augmentation

  1. 应对图上缺少特征的问题(比如只有邻接矩阵),标准方法:
  • constant:给每个节点赋常数特征

image.png

  • one-hot:给每个节点赋唯一ID,将ID转换为独热编码向量的形式(即ID对应索引的元素为1,其他元素都为0)

image.png

  • 两种方法的比较:

image.png

image.png


  1. 应对GNN很难学到特定图结构的问题(如果不用特征专门加以区分,GNN就学不到这些特征):
  • 举例:节点所处环上节点数cycle count这一属性

问题:因为度数相同(都是2),所以无论环上有多少个节点,GNN都会得到相同的计算图(二叉树),无法分别。

解决方法:加上cycle count这一特征(独热编码向量,节点数对应索引的元素为1,其他元素为0)。

image.png

image.png

image.png

  • 其他常用于数据增强的特征:clustering coefficient,centrality(及任何 Lecture 23 中讲过的特征),PageRank4等

image.png


1.2 图结构增强Structure Augmentation

  1. 对稀疏图:增加虚拟边virtual nodes或虚拟节点virtual edges
  • image.png

image.png

  • 虚拟节点:增加一个虚拟节点,这个虚拟节点与图(或者一个从图中选出的子图)上的所有节点相连

这会导致所有节点最长距离变成2(节点A-虚拟节点-节点B)

优点:稀疏图上message passing大幅提升

image.png


  1. 对稠密图:节点邻居抽样node neighborhood sampling7

在message passing的过程中,不使用一个节点的全部邻居,而改为抽样一部分邻居。

image.png

举例来说,对每一层,在传播信息时随机选2个邻居,计算图就会从上图变成下图:

image.png

优点:计算图变小

缺点:可能会损失重要信息(因为有的邻居直接不用了嘛)


可以每次抽样不同的邻居,以增加模型鲁棒性:

image.png


  1. 节点邻居抽样示例8

我们希望经抽样后,结果跟应用所有邻居的结果类似,但还能高效减少计算代价(在后续课程中会专门介绍如何将GNN方法泛化到大型数据上scale up)。

实践证明效果很好。

image.png


2. Learning Objective


  1. 回顾一遍在 Lecture 71 第一节中讲过的学习目标部分:我们如何训练一个GNN模型?

image.png

  1. GNN训练pipeline

输入数据→用GNN训练数据→得到节点嵌入→prediction head(在不同粒度的任务下,将节点嵌入转换为最终需要的预测向量)→得到预测向量和标签→选取损失函数→选取评估指标

(前三部分已经在本章及上章前文讲述过)

image.png


2.1 Prediction Head

  1. 不同粒度下的prediction head:节点级别,边级别,图级别

image.png

  1. image.png

image.png

image.png


  1. image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

一个hierarchical pooling的实际应用:DiffPool10(惯例,我又没咋看懂)

大致来说,就是每一次先用一个GNN计算节点嵌入,然后用另一个GNN(这两个GNN可以同步运算)(两个GNN联合训练jointly train)计算节点属于哪一类,然后按照每一类对图进行池化。每一类得到一个表示向量,保留类间的链接,产生一个新的图。重复这一过程,直至得到最终的表示向量。

将图池化问题与社区发现问题相结合,用节点嵌入识别社区→聚合社区内的节点得到community embeddings→用community embeddings识别supercommunity→聚合supercommunity内的节点得到supercommunity embeddings……

image.png

image.png


2.2 Predictions & Labels

  1. 有监督问题的标签 & 无监督问题的信号

image.png

  1. 有监督学习supervise learning:直接给出标签(如一个分子图是药的概率)

无监督学习unsupervised learning / self-supervised learning:使用图自身的信号(如链接预测:预测两节点间是否有边)

有时这两种情况下的分别比较模糊,在无监督学习任务中也可能有“有监督任务”,如训练GNN以预测节点clustering coefficient

image.png


  1. 有监督学习的标签:按照实际情况而来

举例:

节点级别——引用网络中,节点(论文)属于哪一学科

边级别——交易网络中,边(交易)是否有欺诈行为

图级别——图(分子)是药的概率


建议将无监督学习任务规约到三种粒度下的标签预测任务,因为这种预测任务有很多已做过的工作可资参考,会好做些。

例如聚类任务可视为节点属于某一类的预测任务。

image.png


  1. 无监督学习的信号:

在没有外部标签时,可以使用图自身的信号来作为有监督学习的标签。举例来说,GNN可以预测:

节点级别:节点统计量(如clustering coefficient3, PageRank4 等)

边级别:链接预测(隐藏两节点间的边,预测此处是否存在链接)

图级别:图统计量(如预测两个图是否同构)

这些都是不需要外部标签的

image.png


2.3 损失函数Loss Function

  1. 分类任务常用交叉熵,回归任务常用MSE

image.png

image.png

image.png

image.png

image.png

image.png

image.png

  1. 此外还有其他损失函数,如maximum margin loss,适用于我们关心节点顺序、不关心具体数值而关心其排行的情况。


2.4 评估指标Evaluation Metrics

  1. evaluation metrics12:Accuracy和ROC AUC

image.png

  1. 回归任务

image.png

image.png


  1. 分类任务
  • image.png
  • 二分类任务

对分类阈值敏感的评估指标:

(如果输出范围为 [0,1],我们用0.5作为阈值)

accuracy

precision / recall

(因为数据不平衡时可能会出现accuracy虚高的情况。比如99%的样本都是负样本,那么分类器只要预测所有样本为负就可以获得99%的accuracy,但这没有意义。所以需要其他评估指标来解决这一问题)


对分类阈值不敏感的评估指标:ROC AUC

image.png


二元分类的评估指标(可参考 sklearn.metrics.classification_report — scikit-learn 0.24.2 documentation):

accuracy(分类正确的观测占所有观测的比例)

precision(预测为正的样本中真的为正(预测正确)的样本所占比例)

recall(真的为正的样本中预测为正(预测正确)的样本所占比例)

F1-Score(precision和recall的调和平均值,信息抽取、文本挖掘等领域常用)

混淆矩阵

image.png


ROC曲线:TPR(recall)和FPR之间的权衡(对角斜线说明是随机分类器)

image.png

ROC AUC

ROC曲线下面积。越高越好,0.5是随机分类器,1是完美分类器。

随机抽取一个正样本和一个负样本,正样本被识别为正样本的概率比负样本被识别为正样本的概率高的概率。

image.png


2.5 切分数据集

  1. 将数据集切分为训练集、验证集、测试集

image.png


  1. fixed / random split

fixed split:只切分一次数据集,此后一直使用这种切分方式

random split:随机切分数据集,应用多次随机切分后计算结果的平均值


  1. 我们希望三部分数据之间没有交叉,即留出法hold-out data13。

但由于图结构的特殊性,如果直接像普通数据一样切分图数据集,我们可能不能保证测试集隔绝于训练集:就是说,测试集里面的数据可能与训练集里面的数据有边相连,在message passing的过程中就会互相影响,导致信息泄露。

image.png

image.png

image.png


  1. 解决方式1:transductive setting

输入全图在所有split中可见。仅切分(节点)标签。

image.png


  1. 解决方式2:inductive setting

去掉各split之间的链接,得到多个互相无关的图。这样不同split之间的节点就不会互相影响。

image.png


  1. transductive setting / inductive setting
  • transductive setting:

①测试集、验证集、训练集在同一个图上,整个数据集由一张图构成

②全图在所有split中可见。

③仅适用于节点/边预测任务。


  • inductive setting:

①测试集、验证集、训练集分别在不同图上,整个数据集由多个图构成。

②每个split只能看到split内的图。成功的模型应该可以泛化到没见过的图上。

③适用于节点/边/图预测任务。

image.png


  1. 示例:节点分类任务

transductive:各split可见全图结构,但只能观察到所属节点的标签

inductive:切分多个图,如果没有多个图就将一个图切分成3部分、并去除各部分之间连接的边

image.png


  1. 示例:图预测任务

只适用inductive setting,将不同的图划分到不同的split中。

image.png


  1. 示例:链接预测任务

任务目标:预测出缺失的边。

这是个 unsupervised / self-supervised 任务,需要自行建立标签、自主切分数据集。

需要隐藏一些边,然后让GNN预测边是否存在。

image.png


在切分数据集时,我们需要切分两次。

第一步:在原图中将边分为message edges(用于GNN message passing)和supervision edges(作为GNN的预测目标)。只留下message edges,不将supervision edges传入GNN。

image.png


第二步:切分数据集

 方法1:inductive link prediction split

 划分出3个不同的图组成的split,每个split里的边按照第一步分成message edges和supervision edges

image.png

image.png

 方法2:transductive link prediction split(链接预测任务的默认设置方式)

 在一张图中进行切分:在训练时要留出验证集/测试集的边,而且注意边既是图结构又是标签,所以还要留出supervision edges(要不然还搞啥呢……)

 具体来说:

 训练:用 training message edges 预测 training supervision edges

 验证:用 training message edges 和 training supervision edges 预测 validation edges

 测试:用 training message edges 和 training supervision edges 和 validation edges 预测 test edges

 是个链接越来越多,图变得越来越稠密的过程。这是因为在训练过程之后,supervision edges就被GNN获知了,所以在验证时就要应用 supervision edges 来进行 message passing(测试过程逻辑类似)

image.png

image.png

image.png

image.png

image.png

不同论文中链接预测的数据集切分方式可能不同。


2.6 GNN Training Pipeline

image.png


3. GNN design space 总结


image.png


相关文章
|
2月前
|
机器学习/深度学习 计算机视觉 Python
模型预测笔记(三):通过交叉验证网格搜索机器学习的最优参数
本文介绍了网格搜索(Grid Search)在机器学习中用于优化模型超参数的方法,包括定义超参数范围、创建参数网格、选择评估指标、构建模型和交叉验证策略、执行网格搜索、选择最佳超参数组合,并使用这些参数重新训练模型。文中还讨论了GridSearchCV的参数和不同机器学习问题适用的评分指标。最后提供了使用决策树分类器进行网格搜索的Python代码示例。
158 1
|
4月前
|
机器学习/深度学习 算法 Python
【绝技揭秘】Andrew Ng 机器学习课程第十周:解锁梯度下降的神秘力量,带你飞速征服数据山峰!
【8月更文挑战第16天】Andrew Ng 的机器学习课程是学习该领域的经典资源。第十周聚焦于优化梯度下降算法以提升效率。课程涵盖不同类型的梯度下降(批量、随机及小批量)及其应用场景,介绍如何选择合适的批量大小和学习率调整策略。还介绍了动量法、RMSProp 和 Adam 优化器等高级技巧,这些方法能有效加速收敛并改善模型性能。通过实践案例展示如何使用 Python 和 NumPy 实现小批量梯度下降。
45 1
|
6月前
|
机器学习/深度学习 算法 BI
机器学习笔记(一) 感知机算法 之 原理篇
机器学习笔记(一) 感知机算法 之 原理篇
|
6月前
|
机器学习/深度学习 算法 数据可视化
技术心得记录:机器学习笔记之聚类算法层次聚类HierarchicalClustering
技术心得记录:机器学习笔记之聚类算法层次聚类HierarchicalClustering
69 0
|
6月前
|
机器学习/深度学习 分布式计算 API
技术好文:Spark机器学习笔记一
技术好文:Spark机器学习笔记一
48 0
|
7月前
|
机器学习/深度学习 自然语言处理 PyTorch
fast.ai 机器学习笔记(四)(1)
fast.ai 机器学习笔记(四)
143 1
fast.ai 机器学习笔记(四)(1)
|
7月前
|
机器学习/深度学习 数据挖掘 Python
fast.ai 机器学习笔记(一)(4)
fast.ai 机器学习笔记(一)
133 1
fast.ai 机器学习笔记(一)(4)
|
7月前
|
机器学习/深度学习 Python 文件存储
fast.ai 机器学习笔记(一)(3)
fast.ai 机器学习笔记(一)
140 1
fast.ai 机器学习笔记(一)(3)
|
7月前
|
机器学习/深度学习 监控 算法
LabVIEW使用机器学习分类模型探索基于技能课程的学习
LabVIEW使用机器学习分类模型探索基于技能课程的学习
58 1
|
7月前
|
机器学习/深度学习 Python 索引
fast.ai 机器学习笔记(二)(4)
fast.ai 机器学习笔记(二)
62 0
fast.ai 机器学习笔记(二)(4)