一、消息传递范式
聚合函数和更新函数。
1.1 内置函数和消息传递API
(1)API属性介绍
消息函数:接受一个参数 edges,这是一个 dgl.EdgeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批边。edges有三个成员属性:src、dst和data,分别用于访问源节点、目标节点和边的特征。
聚合函数:接受一个参数 nodes,这是一个 NodeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批节点。 nodes 的成员属性 mailbox 可以用来访问节点收到的消息。 一些最常见的聚合操作包括 sum、max、min 、mean等。聚合函数一般有2个参数,它们的类型都是字符串:
一个用于指定mailbox中的字段名;
一个用于指示目标节点特征的字段名,例如dgl.function.sum('m', 'h')等价于如下所示的对接收到消息求和的用户定义函数:
import torch def reduce_func(nodes): return {'h': torch.sum(nodes.mailbox['m'], dim=1)}
更新函数:接受一个如上所述的参数 nodes。此函数对 聚合函数 的聚合结果进行操作, 通常在消息传递的最后一步将其与节点的特征相结合,并将输出作为节点的新特征。
(2)内置&自定义函数
(1)命名空间dfl.function中实现了常用的(内置的)消息函数和聚合函数(能够自动处理维度广播),当然也可以自定义函数。
(2)(自定义)内置消息函数:可以是一元函数(dgl支持copy函数),也支持二元函数(dgl支持add、sub、mul、div、dot函数):
消息的内置函数的命名约定是u表示源节点,v表示目标节点,e表示边。
这些函数的参数是字符串,指示相应节点和边的输入和输出特征字段名。
ex:要对源节点的hu特征和目标节点的hv特征求和,然后将结果保存在边的he特征上:dgl.function.u_add_v('hu', 'hv', 'he');
- 如下自定义消息函数和内置函数相同:
def message_func(edges): return {'he': edges.src['hu'] + edges.dst['hv']}
(3)在DGL中,也可以在不涉及消息传递的情况下,通过 apply_edges()
单独调用逐边计算。
apply_edges()
的参数是一个消息函数。- 在默认情况下,这个接口将更新所有的边。
- 例如:
import dgl.function as fn graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
(4)消息传递的高级APIupdate_all()
它在单个API调用里合并了消息生成、消息聚合和节点特征更新,为从整体上进行系统优化提供了空间。
update_all()的参数:一个消息、聚合、更新函数(可选,也可以在外面操作,dgl不推荐在update_all中指定更新函数)。
更新函数是一个可选择的参数,可以不使用,而是在update_all执行完后直接对节点特征进行操作;
因为更新函数通常可用纯张量操作实现,所以DGL不推荐在update_all中指定更新函数,如函数:
def updata_all_example(graph): # 在graph.ndata['ft']中存储结果 graph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft')) # 在update_all外调用更新函数 final_ft = graph.ndata['ft'] * 2 return final_ft
ex:graph.update_all(fn.u_mul_e('ft', 'a', 'm')将源节点特征tf和边特征a相乘生成消息m,fn.sum('m', 'ft')再对所有消息求和来更新节点特征ft,再乘2后得到最终结果final_ft。调用后,中间消息m会被清除。
1.2 编写高效的消息传递代码
关于dgl内置函数是如何优化消息传递的内存消耗和计算速度的, 详见文字描述: DGL官方文档 ; 总结来说主要是合并内核, 并行逐边运算, 减少点边拷贝等; 如update_all()函数就是一个效率很高的接口; 如果确实需要使用apply_edges()函数在边上保存消息, 则内存占用会非常大;
(1)一个通过对节点特征降维来减少消息维度的示例:
拼接源节点与目标节点特征, 然后应用一个线性层: W × ( u ∣ ∣ v ) W\times (u||v)W×(u∣∣v);
这样源节点与目标节点特征维数较高, 而线性层输出维数较低;
代码示例:
import torch import torch.nn as nn linear = nn.Parameter(torch.FloatTensor(size=(1, node_feat_dim * 2))) def concat_message_function(edges): return {'cat_feat': torch.cat([edges.src.ndata['feat'], edges.dst.ndata['feat']])} g.apply_edges(concat_message_function) g.edata['out'] = g.edata['cat_feat'] * linear
import dgl.function as fn linear_src = nn.Parameter(torch.FloatTensor(size=(1, node_feat_dim))) linear_dst = nn.Parameter(torch.FloatTensor(size=(1, node_feat_dim))) out_src = g.ndata['feat'] * linear_src out_dst = g.ndata['feat'] * linear_dst g.srcdata.update({'out_src': out_src}) g.dstdata.update({'out_dst': out_dst}) g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))
这两种方法数学上等价, 但后一种方法更加高效, 因为无需再边上保存feat_src
和feat_dst
, 空间占用小, 另外加法可以直接用内置函数u_add_v
进行优化, 内置函数的效率一般比自定义函数要高。
1.3 在图的一部分上进行消息传递
如果用户只想更新图中部分节点,先将想处理的节点编号创建一个子图,然后对其调用update_all()(这也是小批量处理中的常见用法)。
nid = [0, 2, 3, 6, 7, 9] sg = g.subgraph(nid) sg.update_all(message_func, reduce_func, apply_node_func)
1.4 在消息传递中使用边的权重
常见的GNN建模做法:在消息聚合前使用边的权重,如GAT和一些GCN的变种。dgl的处理:
将权重存为边的特征;
在消息函数中用边的特征和源节点的特征相乘。
ex:假定下面的权重eweight是一个形状为(E, *)的张量,E是边的数量。权重存为边的特征,即eweight被用作边的权重(通常是一个标量)。
import dgl.function as fn # 假定eweight是一个形状为(E, *)的张量,E是边的数量。 graph.edata['a'] = eweight graph.update_all(fn.u_mul_e('ft', 'a', 'm'), fn.sum('m', 'ft'))
1.5 在异构图上进行消息传递
本质上异构图的消息传递与同构图并没有太大区别,异构图上的消息传递可以分为两个部分:
对每个关系计算和聚合消息。
对每个结点聚合来自不同关系的消息。
DGLGraph.multi_update_all(etype_dict, cross_reducer, apply_node_func=None):
etype_dict: dict类型, 键为一种关系, 值为这种关系对应的update_all()的参数;
cross_reducer: str类型, 表示跨类型整合函数, 来指定整合不同关系聚合结果的方式, 可以是sum, min, max, mean, stack中之一;
在DGL中,对异构图进行消息传递的接口是 multi_update_all()。 multi_update_all() 接受一个字典。这个字典的每一个键值对里,键是一种关系, 值是这种关系对应 update_all() 的参数。 multi_update_all() 还接受一个字符串来表示跨类型整合函数,来指定整合不同关系聚合结果的方式。 这个整合方式可以是 sum、 min、 max、 mean 和 stack 中的一个。
import dgl.function as fn for c_etype in G.canonical_etypes: srctype, etype, dsttype = c_etype Wh = self.weight[etype](feat_dict[srctype]) # 把它存在图中用来做消息传递 G.nodes[srctype].data['Wh_%s' % etype] = Wh # 指定每个关系的消息传递函数:(message_func, reduce_func). # 注意结果保存在同一个目标特征“h”,说明聚合是逐类进行的。 funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h')) # 将每个类型消息聚合的结果相加。 G.multi_update_all(funcs, 'sum') # 返回更新过的节点特征字典 return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}
小结消息传递的流程:
消息函数(message function):传递消息的目的是将节点计算时需要的信息传递给它,因此对每条边来说,每个源节点将会将自身的Embedding(e.src.data)和边的Embedding(edge.data)传递到目的节点;对于每个目的节点来说,它可能会受到多个源节点传过来的消息,它会将这些消息存储在"邮箱"中。
汇聚函数(reduce function):汇聚函数的目的是根据邻居传过来的消息更新跟新自身节点Embedding,对每个节点来说,它先从邮箱(v.mailbox[‘m’])中汇聚消息函数所传递过来的消息(message),并清空邮箱(v.mailbox[‘m’])内消息;然后该节点结合汇聚后的结果和该节点原Embedding,更新节点Embedding。
更新: message函数的参数是边,包括源节点,目标节点的特征信息,处理完的数据放置到节点的mailbox中。 聚合函数reduce_function或 apply_node函数作用于节点本身,即传入的参数是节点信息和节点的邮箱信息。
二、不带边权重的例子
2.1 消息传递框架
DGL遵循Gilmer等人提出的消息8传递框架,很多GNN模型能符合如下框架:
2.2 GraphSAGE的消息传递
如GraphSAGE可表示为:
我们可以看到消息传递是定向(有方向)的:从一个节点u发送到另一个节点v的消息不一定与从节点v发送到相反方向的节点u的消息相同。
DGL提供了GraphSAGE的实现dgl.nn.SAGEConv
。
import dgl.function as fn class SAGEConv(nn.Module): """Graph convolution module used by the GraphSAGE model. Parameters ---------- in_feat : int Input feature size. out_feat : int Output feature size. """ def __init__(self, in_feat, out_feat): super(SAGEConv, self).__init__() # A linear submodule for projecting the input and neighbor feature to the output. self.linear = nn.Linear(in_feat * 2, out_feat) def forward(self, g, h): """Forward computation Parameters ---------- g : Graph The input graph. h : Tensor The input node feature. """ with g.local_scope(): g.ndata['h'] = h # update_all is a message passing API. g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N')) h_N = g.ndata['h_N'] h_total = torch.cat([h, h_N], dim=1) return self.linear(h_total)
上述代码中的核心部分是g.update_all函数,该函数收集并平均相邻特征。这里有三个概念:
消息函数fn.copy_u('h','m'),它将名为“h”的节点特征复制为发送给邻居的消息
聚合函数fn.mean('m', 'h_N'),该函数对所有接收到的消息中名为’m’的信息进行平均,并将结果保存为新的节点特征’h_N’
update_all让DGL触发所有节点和边的消息函数和聚合函数
2.3 堆叠网络
然后我们可以堆叠自己的GraphSAGE卷积层以构成多层GraphSAGE网络:
import torch.nn.functional as F class Model(nn.Module): def __init__(self, in_feats, h_feats, num_classes): super(Model, self).__init__() self.conv1 = GraphSAGE(in_feats, h_feats) self.conv2 = GraphSAGE(h_feats, num_classes) def forward(self, g, in_feat): h = self.conv1(g, in_feat) h = F.relu(h) h = self.conv2(g, h) return h
2.4 训练网络
import dgl.data dataset = dgl.data.CoraGraphDataset() g = dataset[0] def train(g, model): optimizer = torch.optim.Adam(model.parameters(), lr=0.01) all_logits = [] best_val_acc = 0 best_test_acc = 0 features = g.ndata['feat'] labels = g.ndata['label'] train_mask = g.ndata['train_mask'] val_mask = g.ndata['val_mask'] test_mask = g.ndata['test_mask'] for e in range(200): # Forward logits = model(g, features) # Compute prediction pred = logits.argmax(1) # Compute loss # Note that we should only compute the losses of the nodes in the training set, # i.e. with train_mask 1. loss = F.cross_entropy(logits[train_mask], labels[train_mask]) # Compute accuracy on training/validation/test train_acc = (pred[train_mask] == labels[train_mask]).float().mean() val_acc = (pred[val_mask] == labels[val_mask]).float().mean() test_acc = (pred[test_mask] == labels[test_mask]).float().mean() # Save the best validation accuracy and the corresponding test accuracy. if best_val_acc < val_acc: best_val_acc = val_acc best_test_acc = test_acc # Backward optimizer.zero_grad() loss.backward() optimizer.step() all_logits.append(logits.detach()) if e % 5 == 0: print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format( e, loss, val_acc, best_val_acc, test_acc, best_test_acc)) model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes) train(g, model)
2.5 结果
Using backend: pytorch NumNodes: 2708 NumEdges: 10556 NumFeats: 1433 NumClasses: 7 NumTrainingSamples: 140 NumValidationSamples: 500 NumTestSamples: 1000 Done loading data from cached files. In epoch 0, loss: 1.951, val acc: 0.114 (best 0.114), test acc: 0.103 (best 0.103) In epoch 5, loss: 1.900, val acc: 0.290 (best 0.292), test acc: 0.278 (best 0.277) In epoch 10, loss: 1.790, val acc: 0.462 (best 0.462), test acc: 0.435 (best 0.435) In epoch 15, loss: 1.614, val acc: 0.502 (best 0.502), test acc: 0.489 (best 0.489) In epoch 20, loss: 1.372, val acc: 0.548 (best 0.548), test acc: 0.529 (best 0.529) In epoch 25, loss: 1.087, val acc: 0.592 (best 0.592), test acc: 0.591 (best 0.591) In epoch 30, loss: 0.798, val acc: 0.650 (best 0.650), test acc: 0.639 (best 0.639) In epoch 35, loss: 0.547, val acc: 0.690 (best 0.690), test acc: 0.682 (best 0.682) In epoch 40, loss: 0.358, val acc: 0.710 (best 0.710), test acc: 0.721 (best 0.721) In epoch 45, loss: 0.230, val acc: 0.736 (best 0.736), test acc: 0.734 (best 0.734) In epoch 50, loss: 0.149, val acc: 0.738 (best 0.738), test acc: 0.743 (best 0.744) In epoch 55, loss: 0.099, val acc: 0.740 (best 0.740), test acc: 0.744 (best 0.743) In epoch 60, loss: 0.068, val acc: 0.742 (best 0.742), test acc: 0.743 (best 0.745) In epoch 65, loss: 0.048, val acc: 0.734 (best 0.742), test acc: 0.749 (best 0.745) In epoch 70, loss: 0.036, val acc: 0.736 (best 0.742), test acc: 0.753 (best 0.745) In epoch 75, loss: 0.028, val acc: 0.734 (best 0.742), test acc: 0.755 (best 0.745) In epoch 80, loss: 0.023, val acc: 0.738 (best 0.742), test acc: 0.757 (best 0.745) In epoch 85, loss: 0.019, val acc: 0.738 (best 0.742), test acc: 0.758 (best 0.745) In epoch 90, loss: 0.017, val acc: 0.742 (best 0.742), test acc: 0.756 (best 0.745) In epoch 95, loss: 0.015, val acc: 0.742 (best 0.742), test acc: 0.755 (best 0.745) In epoch 100, loss: 0.013, val acc: 0.742 (best 0.742), test acc: 0.755 (best 0.745) In epoch 105, loss: 0.012, val acc: 0.742 (best 0.742), test acc: 0.755 (best 0.745) In epoch 110, loss: 0.011, val acc: 0.742 (best 0.742), test acc: 0.753 (best 0.745) In epoch 115, loss: 0.010, val acc: 0.742 (best 0.742), test acc: 0.753 (best 0.745) In epoch 120, loss: 0.009, val acc: 0.742 (best 0.742), test acc: 0.754 (best 0.745) In epoch 125, loss: 0.008, val acc: 0.742 (best 0.742), test acc: 0.754 (best 0.745) In epoch 130, loss: 0.008, val acc: 0.742 (best 0.742), test acc: 0.752 (best 0.745) In epoch 135, loss: 0.007, val acc: 0.742 (best 0.742), test acc: 0.752 (best 0.745) In epoch 140, loss: 0.007, val acc: 0.744 (best 0.744), test acc: 0.751 (best 0.751) In epoch 145, loss: 0.006, val acc: 0.744 (best 0.744), test acc: 0.751 (best 0.751) In epoch 150, loss: 0.006, val acc: 0.744 (best 0.744), test acc: 0.749 (best 0.751) In epoch 155, loss: 0.006, val acc: 0.744 (best 0.744), test acc: 0.750 (best 0.751) In epoch 160, loss: 0.005, val acc: 0.742 (best 0.744), test acc: 0.750 (best 0.751) In epoch 165, loss: 0.005, val acc: 0.742 (best 0.744), test acc: 0.753 (best 0.751) In epoch 170, loss: 0.005, val acc: 0.742 (best 0.744), test acc: 0.753 (best 0.751) In epoch 175, loss: 0.005, val acc: 0.742 (best 0.744), test acc: 0.752 (best 0.751) In epoch 180, loss: 0.004, val acc: 0.742 (best 0.744), test acc: 0.753 (best 0.751) In epoch 185, loss: 0.004, val acc: 0.742 (best 0.744), test acc: 0.753 (best 0.751) In epoch 190, loss: 0.004, val acc: 0.742 (best 0.744), test acc: 0.754 (best 0.751) In epoch 195, loss: 0.004, val acc: 0.742 (best 0.744), test acc: 0.754 (best 0.751)
三、带边权重的例子
这里需要改的两个位置,g.update_all
的2个参数,以及在Model
中要设置边权重的传参。其他就没啥变化了。即使用带权平均聚合邻居表示,edata
成员可以保存边权重(特征),这些特征也可以参与消息传递。
# data可以包含边特征信息,同时传递 class WeightedSAGEConv(nn.Module): """ in_feat : int Input feature size. out_feat : int Output feature size. """ def __init__(self, in_feat, out_feat): super(WeightedSAGEConv, self).__init__() # 将input和邻近节点特征映射到outpu线性子模块 self.linear = nn.Linear(in_feat * 2, out_feat) def forward(self, g, h, w): """ g : Graph The input graph. h : Tensor The input node feature. w : Tensor The edge weight. """ with g.local_scope(): g.ndata['h'] = h # 加入边的权重,进行消息传递和更新 g.edata['w'] = w g.update_all(message_func=fn.u_mul_e('h', 'w', 'm'), reduce_func=fn.mean('m', 'h_N')) h_N = g.ndata['h_N'] h_total = torch.cat([h, h_N], dim=1) return self.linear(h_total) # 因为这个数据集中的图没有边的权值, # 所以我们在模型的 forward 函数中手动将所有边的权值赋给1。 class Model(nn.Module): def __init__(self, in_feats, h_feats, num_classes): super(Model, self).__init__() self.conv1 = WeightedSAGEConv(in_feats, h_feats) self.conv2 = WeightedSAGEConv(h_feats, num_classes) def forward(self, g, in_feat): # 3个参数,(g, h, w)即图,点特征,边权重 h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device)) h = F.relu(h) # 设置所有边的权重为1 h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device)) return h
四、DGL按照优先级高低排序推荐做法
直接调用dgl.nn模块;
使用dgl.nn.functional内置方法,适合一些简单操作,如为每个节点计算softmax;
使用update_all,内置的消息函数和聚合函数;
使用用户自定义的消息(message)函数和聚合(reduce)函数。
五、用户自定义函数
DGL允许用户自定义消息函数和聚合函数以获得最大的表达能力。以下是一个用户定义的消息函数,它等价于fn.u_mul_e('h', 'w', 'm')。
def u_mul_e_udf(edges): return {"m": edges.src["h"] * edges.data["w"]}
参数edges共有三个成员:src,data和dst,分别代表所有边的源节点特征,边特征和目标节点特征。
也可以编写自己的聚合函数。例如,下面的函数相当于内置的fn.sum(‘m’, ‘h’)函数,它对传入的消息求和:
def sum_udf(nodes): return {"h": nodes.mailbox["m"].sum(dim=1)} # dim=1,按行求和
总之,DGL将按节点的度数对节点进行分组,对于每个组DGL将传入的消息沿着第2维度(按行)进行堆叠,然后沿第2个维度执行缩减(reduce)以聚合消息。