PyG学习笔记2-CREATING MESSAGE PASSING NETWORKS

简介: PyG学习笔记2-CREATING MESSAGE PASSING NETWORKS

将卷积运算符推广到不规则域通常表示为 邻域聚合或 消息传递方案。通过表示层中节点的节点特征和表示节点之间的(可选)边缘特征。

image.png

where ◻ denotes a differentiable, permutation invariant function, e.g., sum, mean or max, and γ and ϕ denote differentiable functions such as MLPs (Multi Layer Perceptrons).


MessagePassing基类


PyG 提供了MessagePassing基类,它通过自动处理消息传播来帮助创建此类消息传递图神经网络。用户只需要定义函数,即 message()和 ,即 update(),以及要使用的聚合方案,即, 或 。ϕγaggr="add"``aggr="mean"``aggr="max"


具体方法有:


MessagePassing(aggr=“add”, flow=“source_to_target”, node_dim=-2):定义要使用的聚合方案 (“add”“mean”“max”) 和消息传递的流向 (“source_to_target” “target_to_source” node_dim)。此外,该属性还指示要沿哪个轴传播。

MessagePassing.propagate(edge_index, size=None, **kwargs):开始传播消息的初始调用。接收构建消息和更新节点嵌入所需的边缘索引和所有其他数据。propagate()不仅限于交换形状的平方邻接矩阵中的消息,还可以通过作为附加参数传递来交换形状的一般稀疏赋值矩阵(例如,二分图)中的消息。如果设置为None,则假定赋值矩阵是一个方阵。对于具有两组独立节点和索引的二分图,并且每个集合都保存自己的信息,可以通过将信息传递为元组来标记此拆分,例如:[N, N][N, M]size=(N, M)``x=(x_N, x_M)

MessagePassing.message(…):构造到节点的消息,类似于 if 和 if 中的每个边。可以采用最初传递给 的任何参数。此外,传递给 的张量可以映射到相应的节点,并通过追加或附加到变量名称

MessagePassing.update(aggr_out, …):将节点嵌入类比更新为每个节点。将聚合的输出作为第一个参数,以及最初传递给propagate()的任何参数。


实现GCN


GCN在数学上定义为

image.png

其中相邻节点特征首先由权重矩阵变换,按其度数归一化,最后总结。此公式可分为以下步骤:Θ


1.将自循环添加到邻接矩阵。

2.线性变换节点特征矩阵。

3.计算规范化系数。

4.规范化 中的节点功能。ϕ

5.对相邻节点要素进行求和(聚合)。"add"


步骤 1-3 通常在消息传递发生之前计算。可以使用MessagePassing基类轻松处理步骤 4-5。完整层实现如下所示:

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
#实现GCN
class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')  
        # 使用'add'聚合
        self.lin = torch.nn.Linear(in_channels, out_channels)
    def forward(self, x, edge_index):
        # forward 中输入 x 与 edge_index
        # x shape: [N, in_channels]
        # edge_index shape [2, E]
        # Step 1: 将自循环添加到邻接矩阵.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        # Step 2: 线性变换节点特征矩阵.
        x = self.lin(x)
        # Step 3: 计算规范化系数.也就是前后两个矩阵
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        # Step 4-5: 对相邻节点要素进行求和
        return self.propagate(edge_index, x=x, norm=norm)
    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]
        # Step 4: 乘以正则化项.
        return norm.view(-1, 1) * x_j


GCN继承自MessagePassing 和propagation。


在这里,我们首先使用torch_geometric.utils.add_self_loops()函数(步骤 1)将自循环添加到边缘索引中,并通过调用torch.nn.Linear实例(步骤 2)线性转换节点特征。


归一化系数由每个节点的节点度数导出,每个节点的节点度数被转换为每个边。结果保存在形状的张量中(步骤 3)


然后,我们继续调用propagate(),它在内部调用message()、 aggregate()和update()函数。作为消息传播的附加参数,我们传递节点嵌入和规范化系数 。


在message()函数中,我们需要通过规范化相邻节点特征。这里,表示提升的张量,它包含每个边的源节点特征,即每个节点的邻居。节点功能可以通过追加或附加到变量名称来自动解除。事实上,任何张量都可以以这种方式转换,只要它们包含源或目标节点特征。


初始化和调用它非常简单:

conv = GCNConv(16, 32)
x = conv(x, edge_index)


实现Edge Convolution


来自论文:


用来处理图形或点云,并在数学上定义为

image.png


其中表示 MLP。与GCN层类似,我们可以使用MessagePassing类来实现

import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing
class EdgeConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='max') #  "Max" aggregation.
        self.mlp = Seq(Linear(2 * in_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels))
    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]
        return self.propagate(edge_index, x=x)
    def message(self, x_i, x_j):
        # x_i has shape [E, in_channels]
        # x_j has shape [E, in_channels]
        tmp = torch.cat([x_i, x_j - x_i], dim=1)  # tmp has shape [E, 2 * in_channels]
        return self.mlp(tmp)


在message()函数中,我们用于转换每个边的目标节点特征和相对源节点特征


Edge Convolution实际上是一种动态卷积,它使用要素空间中最近的邻居重新计算每个图层的图形。幸运的是,PyG附带了一个名为torch_geometric.nn.pool.knn_graph()的GPU加速批处理k-NN图形生成方法:

from torch_geometric.nn import knn_graph
class DynamicEdgeConv(EdgeConv):
    def __init__(self, in_channels, out_channels, k=6):
        super().__init__(in_channels, out_channels)
        self.k = k
    def forward(self, x, batch=None):
        edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)
        return super().forward(x, edge_index)


在这里,knn_graph()计算一个最近邻图,该图进一步用于调用EdgeConv的方法。forward()


这给我们留下了一个干净的接口来初始化和调用这个层:

conv = DynamicEdgeConv(3, 128, k=6)
x = conv(x, batch)
目录
相关文章
|
3月前
|
算法 数据挖掘
文献解读-Consistency and reproducibility of large panel next-generation sequencing: Multi-laboratory assessment of somatic mutation detection on reference materials with mismatch repair and proofreading deficiency
Consistency and reproducibility of large panel next-generation sequencing: Multi-laboratory assessment of somatic mutation detection on reference materials with mismatch repair and proofreading deficiency,大panel二代测序的一致性和重复性:对具有错配修复和校对缺陷的参考物质进行体细胞突变检测的多实验室评估
32 6
文献解读-Consistency and reproducibility of large panel next-generation sequencing: Multi-laboratory assessment of somatic mutation detection on reference materials with mismatch repair and proofreading deficiency
|
机器学习/深度学习 算法 Oracle
Paper:《“Why Should I Trust You?“: Explaining the Predictions of Any Classifier》翻译与解读
Paper:《“Why Should I Trust You?“: Explaining the Predictions of Any Classifier》翻译与解读
|
7月前
解决Error:All flavors must now belong to a named flavor dimension. Learn more at https://d.android.com
解决Error:All flavors must now belong to a named flavor dimension. Learn more at https://d.android.com
133 5
|
机器学习/深度学习 移动开发 自然语言处理
DEPPN:Document-level Event Extraction via Parallel Prediction Networks 论文解读
当在整个文档中描述事件时,文档级事件抽取(DEE)是必不可少的。我们认为,句子级抽取器不适合DEE任务,其中事件论元总是分散在句子中
144 0
DEPPN:Document-level Event Extraction via Parallel Prediction Networks 论文解读
|
机器学习/深度学习 自然语言处理 数据挖掘
UnifiedEAE: A Multi-Format Transfer Learning Model for Event Argument Extraction via Variational论文解读
事件论元抽取(Event argument extraction, EAE)旨在从文本中抽取具有特定角色的论元,在自然语言处理中已被广泛研究。
97 0
|
机器学习/深度学习 自然语言处理 算法
ACL 2022:Graph Pre-training for AMR Parsing and Generation
抽象语义表示(AMR)以图形结构突出文本的核心语义信息。最近,预训练语言模型(PLM)分别具有AMR解析和AMR到文本生成的高级任务。
165 0
|
机器学习/深度学习 自然语言处理 数据可视化
M2E2: Cross-media Structured Common Space for Multimedia Event Extraction 论文解读
我们介绍了一个新的任务,多媒体事件抽取(M2E2),旨在从多媒体文档中抽取事件及其参数。我们开发了第一个基准测试
114 0
|
机器学习/深度学习 自然语言处理 数据可视化
EventGraph:Event Extraction as Semantic Graph Parsing 论文解读
事件抽取涉及到事件触发词和相应事件论元的检测和抽取。现有系统经常将事件抽取分解为多个子任务,而不考虑它们之间可能的交互。
85 0
|
数据挖掘
MUSIED: A Benchmark for Event Detection from Multi-Source Heterogeneous Informal Texts 论文解读
事件检测(ED)从非结构化文本中识别和分类事件触发词,作为信息抽取的基本任务。尽管在过去几年中取得了显著进展
73 0
|
机器学习/深度学习 算法 数据挖掘
【多标签文本分类】Improved Neural Network-based Multi-label Classification with Better Initialization ……
【多标签文本分类】Improved Neural Network-based Multi-label Classification with Better Initialization ……
140 0
【多标签文本分类】Improved Neural Network-based Multi-label Classification with Better Initialization ……