cs224w(图机器学习)2021冬季课程学习笔记19 Deep Generative Models for Graphs

简介: 本章主要内容:首先介绍了深度图生成模型的基本情况,然后介绍了直接从图数据集中学习的GraphRNN模型1,最后介绍了医药生成领域的GCPN模型

1. Deep Generative Models for Graphs


对深度图生成模型,有两种看待问题的视角:


第一种是说,图生成任务很重要,我们此前已经学习过传统图生成模型3,接下来将介绍在图表示学习框架下如何用深度学习的方法来实现图生成任务。


另一种视角是将其视为图表示学习任务的反方向任务。

课程此前学习过的图表示学习任务4 deep graph encoders:输入图数据,经图神经网络输出节点嵌入

image.png


而深度图生成模型可以说是deep graph decoders:输入little noise parameter或别的类似东西,输出图结构数据

image.png


2. Machine Learning for Graph Generation


  1. 图生成任务分为两种:
  • realistic graph generation

生成与给定的一系列图相似的图(本章2、3节重点)

  • goal-directed graph generation

生成优化特定目标或约束的图(举例:生成/优化药物分子)(本章第4节介绍)

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png

image.png


3. GraphRNN: Generating Realistic Graphs


  1. GraphRNN的优点在于它不需要任何inductive bias assumptions,就可以直接实现图生成任务。


  1. GraphRNN的思想:sequentially增加节点和边,最终生成一张图。如图所示:

image.png


  1. 将图建模为序列:

给定图 G  及其对应的node ordering π ,我们可以将其唯一映射为一个node and edge additions的序列 S π 如图所示,序列 S π  的每个元素都是加一个节点和这个节点与之前节点连接的边:

image.png

S π 是一个sequence的sequence,有两个级别:节点级别每次添加一个节点,边级别每次添加新节点与之前节点之间的边。


节点级别:

image.png


节点级别的每一步是一个边级别的序列:每一个元素是是否与该节点添加一条边,即形成一个如图所示的0-1变量序列:

image.png

这里的node ordering是随机选的,随后我们会讨论这一问题。

如图所示,每一次是生成邻接矩阵(黄色部分)中的一个节点(向右),每个节点生成一列边(向下):

image.png

这样我们就将图生成问题转化为序列生成问题。


我们需要建模两个过程:

(1) 生成一个新节点的state(节点级别序列)

(2) 根据新节点state生成它与之前节点相连的边(边级别序列)


方法:用Recurrent Neural Networks (RNNs) 建模这些过程

image.png


  1. RNN

RNNs是为序列数据所设计的,它sequentially输入序列数据以更新其hidden states,其hidden states包含已输入RNN的所有信息。更新过程由RNN cells实现。

图示流程:

image.png

image.png

image.png


  1. GraphRNN: Two levels of RNN

GraphRNN有一个节点级别RNN和一个边级别RNN,节点级别RNN生成边级别RNN的初始state,边级别RNN sequentially预测这个新节点与每一个之前的节点是否相连。

image.png

如图所示,边级别RNN预测新加入的节点是否与之前各点相连:

image.png

接下来将介绍如何用这个RNN生成序列。

image.png


  1. (1) 用RNN生成序列:用前一个cell的输出作为下一个cell的输入(x t + 1 = y t )。

(2) 初始化输入序列:用 start of sequence token (SOS) 作为初始输入。SOS常是一个全0或全1的向量。

(3) 结束生成任务:用 end of sequence token (EOS) 作为RNN额外输出。

如果输出EOS=0,则RNN继续生成;如果过输出EOS=1,则RNN停止生成。

image.png


模型如图所示:

这样的问题在于模型是确定的,但我们需要生成的是分布,所以需要模型具有随机性。

image.png

image.png

image.png


  1. RNN at Test Time

我们假设已经训练好了模型:

y t 是 x t + 1 是否为1这一遵从伯努利分布事件的概率,从而根据模型我们可以从输入输出 y t ,从而抽样出 x t + 1  。

如图所示:

image.png


9。 RNN at Training Time

在训练过程中,我们已知的数据就是序列 y (该节点与之前每一节点是否相连的0-1元素组成的序列)。

我们使用teacher forcing7 的方法,将每一个输入都从前一个节点的输出换成真实序列值,而用真实序列值与模型输出值来计算损失函数。如图所示:

image.png

image.png

image.png


  1. Putting Things Together

我们的计划是:


  • 增加一个新节点:跑节点RNN,用其每一步输出来初始化边RNN
  • 为新节点增加新边:跑边RNN,预测新节点是否与每一之前节点相连
  • 增加另一个新节点:用边RNN最后的hidden state来跑下一步的节点RNN
  • 停止图生成任务:如果边RNN在第一步输出EOS,则我们知道新节点上没有任何一条边,即不再与之前的图有连接,从而停止图生成过程。

image.png


  1. 训练过程

假设节点1已在图中,现在添加节点2:输入SOS到节点RNN中

image.png

边RNN预测节点2是否会与节点1相连:输入SOS到边RNN中,输出节点2是否会与节点1相连的概率0.5

image.png

用边RNN的hidden state更新节点RNN:

image.png

边RNN预测节点3是否会与节点1、2相连:输入SOS到边RNN中,输出节点3是否会与节点2相连的概率0.6;输入节点3与节点2不相连的真实值0到下一个cell中,输出节点3是否会与节点2相连的概率0.4:

image.png

用边RNN的hidden state更新节点RNN:

image.png

我们已知节点4不与任何之前节点相连,所以停止生成任务:输入SOS到边RNN中,没看懂这里是不是用teacher forcing强制停止的意思。

image.png

每一步我们都用真实值作为监督,如图所示,就跟右上角的图形式或邻接矩阵形式一样的真实值:

image.png

通过时间反向传播,随time step9 累积梯度,如图所示:

image.png


  1. 测试阶段
  • 根据预测出来的边分布抽样边
  • 用GraphRNN自己的预测来代替每一步输入(就类似训练阶段如果不用tearcher forcing的那种效果)

如图所示:

image.png


  1. GraphRNN总结:

通过生成一个2级序列来生成一张图,用RNN来生成序列。如图中所示,节点级别RNN向右预测,边级别RNN向下预测。

接下来我们要使RNN tractable,以及对其效果进行评估。

image.png


  1. tractability

在此前的模型中,每一个新节点都可以与其前任何一个节点相连,这需要太多步边生成了,需要产生一整个邻接矩阵(如上图所示),也有太多过长的边依赖了(不管已经有了多少个节点,新节点还要考虑是否与最前面的几个节点有边连接关系)。

如果我们使用随机的node ordering,那我们对每个新生成的节点就是都要考虑它与之前每一个节点是否有边(图中左下角所示):

image.png


  1. BFS

但是如果我们换成一种BFS的node ordering,那么在对每个边考虑它可能相连的之前节点的过程如图所示,我们只需要考虑在BFS时它同层和上一层的节点(因为再之前的节点跟它不会有邻居关系),即只需要考虑2步的节点而非 n − 1 n-1n−1 步的节点:

image.png

这样的好处有二:

(1) 减少了可能存在的node ordering数量(从 O ( n ! ) O(n!)O(n!) 减小到不同BFS ordering的数量)

(2) 减少了边生成的步数(因为不需要看之前所有节点了,只需要看一部分最近的节点即可)

image.png

在运行GraphRNN时仅需考虑该节点及其之前的一部分节点,如图所示:

image.png


  1. 对生成图的评估

我们的数据集是若干图,输出也是若干图,我们要求评估这两组图之间的相似性。有直接从视觉上观察其相似性和通过图统计指标来衡量其相似性两种衡量方式。

image.png


  • visual similarity

就直接看,能明显地发现在grid形式的图上,GraphRNN跟输入数据比传统图生成模型(主要用于生成网络而非这种grid图)要更像很多:

image.png


(图中Kronecker就是上节课讲的那个模型。其他baseline模型具体哪个对应哪个可以在 1这篇论文中找。这个图就是原论文中的插图)


即使在传统图生成模型应用的有社区的社交网络上,GraphRNN也表现很好,如图所示。这体现了GraphRNN的可泛化能力。

image.png

  • graph statistics similarity

我们想找到一些比目测更精确的比较方式,但直接在两张图的结构之间作比较很难(同构性检测是NP的),因此我们选择比较图统计指标。

典型的图统计指标包括:

(1) degree distribution (Deg.)

(2) clustering coefficient distribution (Clus.)

(3) orbit count statistics 11

注意:每个图统计指标都是一个概率分布。

image.png

所以我们一要比较两种图统计指标(两个概率分布),解决方法是earth mover distance (EMD);二要比较两个图统计指标的集合(两个概率分布的集合),解决方法是基于EMD的maximum mean discrepancy (MMD)。

image.png

  • earth mover distance (EMD)

用于比较两个分布之间的相似性。在直觉上就是衡量需要将一种分布编程另一种分布所需要移动的最小“泥土量”(面积)。总之这里有个公式,但是我也没仔细看具体怎么搞的。或许可以参考一下EMD的英文维基百科Earth mover’s distance - Wikipedia,以后有缘可以学习:

image.png

  • maximum mean discrepancy (MMD)

基于元素相似性,比较集合相似性:使用L2距离,对每个元素用EMD计算距离,然后用L2距离计算MMD。

image.png

呃但是这个公式我委实是没有看懂:

image.png

……什么东西啊这是?

对图生成结果的评估:

image.png

计算举例:通过计算原图域生成图之前在clustering coefficient distribution上的区别,我们发现GraphRNN是表现最好的(即最相似的)。

image.png


4. Application of Deep Graph Generative Models


本节主要介绍深度图生成模型在药物发现领域的应用GCPN2。


  1. 药物发现领域的问题是:我们如何学习一个模型,使其生成valid、真实的分子,且具有优化过的某一属性得分(如drug-likeness或可溶性等)?

image.png


  1. 这种生成任务就是goal-directed graph generation:

① 优化一个特定目标得分(high scores),如drug-likeness

② 遵从内蕴规则(valid),如chemical validity rules

③ 从示例中学习(realistic),如模仿一个分子图数据集

image.png


  1. 这一任务的难点在于需要在机器学习中引入黑盒:像drug-likeness这种受物理定律决定的目标是我们不可知的。

image.png


  1. 我们的解决思路是使用强化学习的思想

强化学习是一个机器学习agent观察环境environment,采取行动action来与环境互动interact,收到正向或负面的反馈reward,根据反馈从这一回环之中进行学习。回环如图所示。

其核心思想在于agent是直接从环境这一对agent的黑盒中进行学习的。

image.png


  1. 我们的解决方法是GCPN:graph convolutional policy network

结合了图表示学习和强化学习

核心思想:

  • GNN捕获图结构信息
  • 强化学习指导导向预期目标的图生成过程
  • 有监督训练模拟给定数据集的样例

image.png

  1. GCPN vs GraphRNN
  • 共同点:

sequentially生成图

模仿给定的图数据集

  • 主要差异:
  1. GCPN用GNN来预测图生成行为

优势:GNN比RNN更具有表现力

劣势:GNN比RNN更耗时(但是分子一般都是小图,所以我们负担得起这个时间代价)

  1. GCPN使用RL来直接生成符合我们目标的图。RL使goal-directed graph generation成为可能。

image.png


  1. sequential graph generation

GraphRNN:基于RNN hidden states(捕获至此已生成图部分的信息)预测图生成行为。

image.png

image.png


  1. GCPN概览

如图所示,首先插入节点5,然后用GNN预测节点5会与哪些节点相连,抽样边(action),检验其化学validity,计算reward。这个具体流程其实我也妹搞明白,强化学习这部分我就不太懂。以后有缘再仔细研究。

image.png


  1. 我们如何设置reward?

我们设置两种reward:

一种是step reward,学习执行valid action:每一步对valid action分配小的正反馈。

一种是final reward,优化预期属性:在最后对高预期属性分配正反馈。

reward=final reward + step reward

image.png


  1. 训练过程分两部分:
  • 有监督训练:通过模仿给定被观测图的行为训练policy,用交叉熵梯度下降。(跟GraphRNN中的一样)
  • 强化学习训练:训练policy以优化反馈,使用standard policy gradient algorithm。这一步我也不懂,它反正说可以参考CS234等强化学习课程来了解这部分。以后有缘再了解吧。

image.png

image.png


  1. GCPN实验结果

在logP和QED这些医药上要优化的指标上都表现很好:

image.png


constrained optimization / complete任务:编辑给定分子,在几步之后就能达到高属性得分(如在以logP作为罚项的基础上,提升辛醇的可溶性):

image.png


5. 本章总结


  1. 复杂图可以用深度学习通过sequential generation成功生成。
  2. 图生成决策的每一步都基于hidden state。

hidden state可以是隐式的向量表示(因为RNN的中间过程都在hidden state里面,所以说是隐式的),由RNN解码;也可以是显式的中间生成图,由GCN解码。

  1. 可以实现的任务包括模仿给定的图数据集和往给定目标优化图。

image.png

相关文章
|
7月前
|
机器学习/深度学习 供应链 算法
机器学习课程学习随笔
机器学习课程学习随笔
|
7月前
|
机器学习/深度学习 数据可视化 PyTorch
零基础入门语义分割-地表建筑物识别 Task5 模型训练与验证-学习笔记
零基础入门语义分割-地表建筑物识别 Task5 模型训练与验证-学习笔记
510 2
|
6月前
|
机器学习/深度学习 搜索推荐 PyTorch
【机器学习】图神经网络:深度解析图神经网络的基本构成和原理以及关键技术
【机器学习】图神经网络:深度解析图神经网络的基本构成和原理以及关键技术
1374 2
|
7月前
|
机器学习/深度学习 算法 图计算
图机器学习入门:基本概念介绍
图机器学习是机器学习的分支,专注于处理图形结构数据,其中节点代表实体,边表示实体间关系。本文介绍了图的基本概念,如无向图与有向图,以及图的性质,如节点度、邻接矩阵。此外,还讨论了加权图、自循环、多重图、双部图、异构图、平面图和循环图。图在描述数据关系和特征方面具有灵活性,为机器学习算法提供了丰富的结构信息。
189 0
|
7月前
|
机器学习/深度学习 人工智能 算法
机器学习的魔法(一)从零开始理解吴恩达的精炼笔记
机器学习的魔法(一)从零开始理解吴恩达的精炼笔记
|
7月前
|
机器学习/深度学习 数据可视化 算法
【学习打卡04】可解释机器学习笔记之Grad-CAM
【学习打卡04】可解释机器学习笔记之Grad-CAM
|
7月前
|
机器学习/深度学习
Coursera 吴恩达Machine Learning(机器学习)课程 |第五周测验答案(仅供参考)
Coursera 吴恩达Machine Learning(机器学习)课程 |第五周测验答案(仅供参考)
|
7月前
|
机器学习/深度学习 人工智能 文字识别
【学习打卡03】可解释机器学习笔记之CAM类激活热力图
【学习打卡03】可解释机器学习笔记之CAM类激活热力图
|
7月前
|
机器学习/深度学习 存储 数据可视化
【学习打卡02】可解释机器学习笔记之ZFNet
【学习打卡02】可解释机器学习笔记之ZFNet
|
机器学习/深度学习 编解码 计算机视觉
Python机器学习和图像处理学习笔记
Python机器学习和图像处理学习笔记

热门文章

最新文章