1 The constructor of a DGL NN module
The constructor performs the following tasks:
- Set options.
- Register learnable parameters or submodules.
- Initialize the parameters.
```python
import torch.nn as nn

from dgl.utils import expand_as_pair

class SAGEConv(nn.Module):
    def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type,
                 bias=True,
                 norm=None,
                 activation=None):
        super(SAGEConv, self).__init__()

        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._aggre_type = aggregator_type
        self.norm = norm
        self.activation = activation
```
In the constructor, the user first needs to set the data dimensions. For a typical PyTorch module these are the input, output, and hidden dimensions. For a graph neural network, the input dimension is further split into a source node feature dimension and a destination node feature dimension.
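This split is what expand_as_pair handles: it accepts either a single size or a (source, destination) pair. A minimal sketch of its behavior (the numbers below are made up for illustration):

```python
from dgl.utils import expand_as_pair

# A single integer is duplicated for source and destination nodes.
print(expand_as_pair(10))        # (10, 10)
# A (src, dst) tuple is returned unchanged.
print(expand_as_pair((10, 20)))  # (10, 20)
```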
Besides the data dimensions, a typical option for a graph neural network is the aggregation type (self._aggre_type). For a given destination node, the aggregation type determines how the messages arriving on its edges are aggregated. Commonly used aggregation types include mean, sum, max, and min. Some modules may use more sophisticated aggregation functions such as lstm.
norm in the code above is a callable used for feature normalization. In the SAGEConv paper, the normalization can be the L2 normalization:

$$h_v = \frac{h_v}{\lVert h_v \rVert_2}$$
```python
        # Aggregator type: mean, max_pool, lstm, gcn
        if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
            raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
        if aggregator_type == 'max_pool':
            self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
        if aggregator_type == 'lstm':
            self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
        if aggregator_type in ['mean', 'max_pool', 'lstm']:
            self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
        self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
        self.reset_parameters()
```
The constructor then registers parameters and submodules. In SAGEConv, the submodules depend on the aggregation type. They are plain PyTorch NN modules such as nn.Linear and nn.LSTM. At the end of the constructor, reset_parameters() is called to initialize the weights.
```python
    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'max_pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
```
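With the constructor in place, the module is instantiated like any other PyTorch module. A small usage sketch (the sizes and the aggregator choice are arbitrary):

```python
import torch.nn.functional as F

# Instantiate the SAGEConv defined above (illustrative sizes).
conv = SAGEConv(in_feats=10, out_feats=16, aggregator_type='mean',
                activation=F.relu)
print(conv)
```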
2 Writing the forward function of a DGL NN module
In an NN module, the forward() function performs the actual message passing and computation. Compared with a typical PyTorch NN module, whose arguments are usually tensors, a DGL NN module takes one additional argument: a dgl.DGLGraph. The work done in forward() can generally be divided into three parts:
- Check that the input graph object is valid.
- Perform message passing and aggregation.
- Update the features after aggregation to produce the output.
The rest of this section walks through the forward() function of the SAGEConv example.
Checking the validity of the input graph object
```python
    def forward(self, graph, feat):
        with graph.local_scope():
            # Determine the graph type, then expand the input features accordingly
            feat_src, feat_dst = expand_as_pair(feat, graph)
```
forward() needs to handle many corner cases in the input that can lead to invalid values during computation and message passing. A typical check in conv modules such as GraphConv is to verify that the input graph contains no 0-in-degree nodes. When a node has an in-degree of 0, its mailbox is empty and the aggregation output is all zeros, which can silently hurt model performance. In SAGEConv, however, the aggregated features are concatenated with the node's original features, so the output of forward() will not be all zeros and no such check is needed.
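If a module does need such a guard, it can be written along the following lines. This is only a sketch of the idea (GraphConv raises a DGL-specific error; a plain ValueError is used here for illustration):

```python
        # Inside forward(), before message passing (illustrative only).
        if (graph.in_degrees() == 0).any():
            raise ValueError('The graph contains 0-in-degree nodes; '
                             'their aggregated output would be all zeros.')
```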
A DGL NN module should be reusable across different types of graph input: homogeneous graphs, heterogeneous graphs, and subgraph blocks used in minibatch training.
The math of SAGEConv is:

$$h_{\mathcal{N}(dst)}^{(l+1)} = \mathrm{aggregate}\left(\left\{h_{src}^{(l)}, \forall src \in \mathcal{N}(dst)\right\}\right)$$

$$h_{dst}^{(l+1)} = \sigma\left(W \cdot \mathrm{concat}\left(h_{dst}^{(l)}, h_{\mathcal{N}(dst)}^{(l+1)}\right)\right)$$

$$h_{dst}^{(l+1)} = \mathrm{norm}\left(h_{dst}^{(l+1)}\right)$$
The source node features feat_src and destination node features feat_dst need to be specified according to the graph type. The function that determines the graph type and expands feat into feat_src and feat_dst is expand_as_pair(). Its implementation is shown below.
```python
# In the DGL source, `Mapping` is collections.abc.Mapping and `F` is DGL's backend module.
def expand_as_pair(input_, g=None):
    if isinstance(input_, tuple):
        # Bipartite graph case
        return input_
    elif g is not None and g.is_block:
        # Block (minibatch subgraph) case
        if isinstance(input_, Mapping):
            input_dst = {
                k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
                for k, v in input_.items()}
        else:
            input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
        return input_, input_dst
    else:
        # Homogeneous graph case
        return input_, input_
```
For full-graph training on a homogeneous graph, the source and destination nodes are the same: they are all the nodes in the graph. For a heterogeneous graph, the graph can be decomposed into several bipartite graphs, one per relation, where a relation is a triple (src_type, edge_type, dst_type).
When the input feature feat is a tuple, the graph is treated as bipartite: the first element of the tuple holds the source node features and the second holds the destination node features.
In minibatch training, the computation is applied to a subgraph sampled for a given batch of destination nodes. Such a subgraph is called a block in DGL. During block creation, dst nodes are placed at the front of the node list, so feat_dst can be obtained by taking the first rows of the features, i.e. the slice [0:g.number_of_dst_nodes()].
Once feat_src and feat_dst are determined, the computation is identical for the three graph types above.
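The block case can be seen directly with dgl.to_block. The following sketch (the graph, seed nodes, and feature sizes are made up) shows that destination nodes sit at the front of a block's source node list, so a simple slice recovers feat_dst:

```python
import dgl
import torch

# A tiny homogeneous graph with 4 nodes and some features (illustrative values).
g = dgl.graph(([0, 1, 2], [1, 2, 3]))
feat = torch.randn(4, 8)

# Keep only edges that point to the seed nodes {2, 3}, then convert to a block,
# mimicking one step of neighbour sampling.
seeds = torch.tensor([2, 3])
sg = dgl.in_subgraph(g, seeds)
block = dgl.to_block(sg, dst_nodes=seeds)

# Destination nodes come first among the block's source nodes,
# so feat_dst is the first number_of_dst_nodes() rows of feat_src.
feat_src = feat[block.srcdata[dgl.NID]]
feat_dst = feat_src[:block.number_of_dst_nodes()]
print(feat_src.shape, feat_dst.shape)  # torch.Size([3, 8]) torch.Size([2, 8])
```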
Message passing and aggregation
```python
import dgl.function as fn
import torch.nn.functional as F
from dgl.utils import check_eq_shape

        h_self = feat_dst  # keep the destination nodes' own features for the final projection
        if self._aggre_type == 'mean':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'gcn':
            check_eq_shape(feat)
            graph.srcdata['h'] = feat_src
            graph.dstdata['h'] = feat_dst
            graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
            # Divide by the in-degree (the +1 accounts for the node itself)
            degs = graph.in_degrees().to(feat_dst)
            h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
        elif self._aggre_type == 'max_pool':
            graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
            graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        else:
            raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

        # The gcn aggregation in GraphSAGE does not require fc_self
        if self._aggre_type == 'gcn':
            rst = self.fc_neigh(h_neigh)
        else:
            rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)
```
The code above performs the message passing and aggregation. This part of the code varies from module to module.
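To make the update_all pattern used above concrete, here is a standalone sketch on a toy graph (graph and values invented for illustration):

```python
import dgl
import dgl.function as fn
import torch

# Toy graph with edges 0->2, 1->2, 2->3.
g = dgl.graph(([0, 1, 2], [2, 2, 3]))
g.ndata['h'] = torch.arange(4, dtype=torch.float32).view(-1, 1)

# Message function: copy each source node's 'h' onto its out-edges as 'm'.
# Reduce function: average the incoming messages at each destination node.
g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_neigh'))
print(g.ndata['h_neigh'])  # node 2 receives (0 + 1) / 2 = 0.5, node 3 receives 2.0
```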
Updating the features after aggregation as output
```python
        # Activation function
        if self.activation is not None:
            rst = self.activation(rst)
        # Normalization
        if self.norm is not None:
            rst = self.norm(rst)
        return rst
```
The last part of forward() updates the node features after message aggregation. A common update is to apply the activation function and the normalization configured in the constructor.
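For instance, the L2 normalization mentioned earlier can be supplied through the norm option. This is just one possible choice, shown as a sketch with arbitrary sizes:

```python
import torch.nn.functional as F

def l2_norm(x):
    # Hypothetical choice: L2-normalize each node's output feature.
    return F.normalize(x, p=2, dim=-1)

conv = SAGEConv(in_feats=10, out_feats=16, aggregator_type='mean', norm=l2_norm)
```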
3 A simple graph classification task
In this tutorial we learn how to perform graph classification with DGL. The goal of the task is to classify graphs drawn from eight different topology types.
Here we use DGL's built-in synthetic dataset data.MiniGCDataset. It contains eight different types of graphs, with the same number of graph samples for each class.
```python
from dgl.data import MiniGCDataset
import matplotlib.pyplot as plt
import networkx as nx

# A dataset with 80 samples, each graph is of size [10, 20]
dataset = MiniGCDataset(80, 10, 20)
graph, label = dataset[0]
fig, ax = plt.subplots()
nx.draw(graph.to_networkx(), ax=ax)
ax.set_title('Class: {:d}'.format(label))
plt.show()
```
```
Using backend: pytorch
```
Creating batched graph data
```python
import dgl
import torch

def collate(samples):
    # The input `samples` is a list of pairs (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels, dtype=torch.long)
```
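dgl.batch merges a list of graphs into one large graph whose connected components are the original graphs, so an entire minibatch can be processed in a single message-passing pass. A quick sketch with two made-up graphs:

```python
import dgl

g1 = dgl.graph(([0, 1], [1, 2]))  # 3 nodes
g2 = dgl.graph(([0], [1]))        # 2 nodes
bg = dgl.batch([g1, g2])

print(bg.batch_size)         # 2
print(bg.num_nodes())        # 5
print(bg.batch_num_nodes())  # tensor([3, 2])
```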
Building the graph classifier
```python
from dgl.nn.pytorch import GraphConv
import torch.nn as nn
import torch.nn.functional as F

class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        # Use the node degree as the initial node feature. For undirected graphs,
        # the in-degree is the same as the out-degree.
        h = g.in_degrees().view(-1, 1).float()
        # Perform graph convolution and activation.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        g.ndata['h'] = h
        # Calculate the graph representation by averaging all node representations.
        hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)
```
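Note that dgl.mean_nodes is batch-aware: on a batched graph it returns one averaged vector per original graph rather than a single global mean. A small self-contained sketch (graphs and feature values are made up):

```python
import dgl
import torch

bg = dgl.batch([dgl.graph(([0, 1], [1, 2])), dgl.graph(([0], [1]))])
bg.ndata['h'] = torch.ones(bg.num_nodes(), 4)
# One mean per original graph, so the readout has shape (batch_size, 4).
print(dgl.mean_nodes(bg, 'h').shape)  # torch.Size([2, 4])
```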
```python
import torch.optim as optim
from torch.utils.data import DataLoader

# Create training and test sets.
trainset = MiniGCDataset(320, 10, 20)
testset = MiniGCDataset(80, 10, 20)
# Use PyTorch's DataLoader and the collate function defined before.
data_loader = DataLoader(trainset, batch_size=32, shuffle=True,
                         collate_fn=collate)

# Create model
model = Classifier(1, 256, trainset.num_classes)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()

epoch_losses = []
for epoch in range(80):
    epoch_loss = 0
    for iter, (bg, label) in enumerate(data_loader):
        prediction = model(bg)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (iter + 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)
```
```
Epoch 0, loss 2.0010
Epoch 1, loss 1.9744
Epoch 2, loss 1.9551
Epoch 3, loss 1.9444
Epoch 4, loss 1.9318
Epoch 5, loss 1.9170
Epoch 6, loss 1.8928
Epoch 7, loss 1.8573
Epoch 8, loss 1.8212
Epoch 9, loss 1.7715
Epoch 10, loss 1.7152
Epoch 11, loss 1.6570
Epoch 12, loss 1.5885
Epoch 13, loss 1.5308
Epoch 14, loss 1.4719
Epoch 15, loss 1.4158
Epoch 16, loss 1.3515
Epoch 17, loss 1.2963
Epoch 18, loss 1.2417
Epoch 19, loss 1.1978
Epoch 20, loss 1.1698
Epoch 21, loss 1.1086
Epoch 22, loss 1.0780
Epoch 23, loss 1.0459
Epoch 24, loss 1.0192
Epoch 25, loss 1.0017
Epoch 26, loss 1.0297
Epoch 27, loss 0.9784
Epoch 28, loss 0.9486
Epoch 29, loss 0.9327
Epoch 30, loss 0.9133
Epoch 31, loss 0.9265
Epoch 32, loss 0.9177
Epoch 33, loss 0.9303
Epoch 34, loss 0.8666
Epoch 35, loss 0.8639
Epoch 36, loss 0.8474
Epoch 37, loss 0.8858
Epoch 38, loss 0.8393
Epoch 39, loss 0.8306
Epoch 40, loss 0.8204
Epoch 41, loss 0.8057
Epoch 42, loss 0.7998
Epoch 43, loss 0.7909
Epoch 44, loss 0.7840
Epoch 45, loss 0.7807
Epoch 46, loss 0.7882
Epoch 47, loss 0.7701
Epoch 48, loss 0.7612
Epoch 49, loss 0.7563
Epoch 50, loss 0.7430
Epoch 51, loss 0.7354
Epoch 52, loss 0.7357
Epoch 53, loss 0.7326
Epoch 54, loss 0.7249
Epoch 55, loss 0.7181
Epoch 56, loss 0.7146
Epoch 57, loss 0.7306
Epoch 58, loss 0.7143
Epoch 59, loss 0.7018
Epoch 60, loss 0.7130
Epoch 61, loss 0.7003
Epoch 62, loss 0.6977
Epoch 63, loss 0.7120
Epoch 64, loss 0.6979
Epoch 65, loss 0.7370
Epoch 66, loss 0.7223
Epoch 67, loss 0.6980
Epoch 68, loss 0.6891
Epoch 69, loss 0.6715
Epoch 70, loss 0.6736
Epoch 71, loss 0.6709
Epoch 72, loss 0.6583
Epoch 73, loss 0.6717
Epoch 74, loss 0.6683
Epoch 75, loss 0.6656
Epoch 76, loss 0.6477
Epoch 77, loss 0.6414
Epoch 78, loss 0.6442
Epoch 79, loss 0.6398
```
```python
plt.title('cross entropy averaged over minibatches')
plt.plot(epoch_losses)
plt.show()
```
```python
model.eval()
# Convert a list of tuples to two lists
test_X, test_Y = map(list, zip(*testset))
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
probs_Y = torch.softmax(model(test_bg), 1)
sampled_Y = torch.multinomial(probs_Y, 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
print('Accuracy of sampled predictions on the test set: {:.4f}%'.format(
    (test_Y == sampled_Y.float()).sum().item() / len(test_Y) * 100))
print('Accuracy of argmax predictions on the test set: {:4f}%'.format(
    (test_Y == argmax_Y.float()).sum().item() / len(test_Y) * 100))
```
```
Accuracy of sampled predictions on the test set: 58.7500%
Accuracy of argmax predictions on the test set: 62.500000%
```