图神经网络17-DGL实战:构建图神经网络(GNN)模块

本文涉及的产品
函数计算FC,每月15万CU 3个月
简介: 图神经网络17-DGL实战:构建图神经网络(GNN)模块

1 DGL NN模块的构造函数


构造函数完成以下几个任务:

  1. 设置选项。
  2. 注册可学习的参数或者子模块。
  3. 初始化参数。

import torch.nn as nn
    from dgl.utils import expand_as_pair
    class SAGEConv(nn.Module):
        def __init__(self,
                     in_feats,
                     out_feats,
                     aggregator_type,
                     bias=True,
                     norm=None,
                     activation=None):
            super(SAGEConv, self).__init__()
            self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
            self._out_feats = out_feats
            self._aggre_type = aggregator_type
            self.norm = norm
            self.activation = activation


在构造函数中,用户首先需要设置数据的维度。对于一般的PyTorch模块,维度通常包括输入的维度、输出的维度和隐层的维度。


对于图神经网络,输入维度可被分为源节点特征维度和目标节点特征维度。

除了数据维度,图神经网络的一个典型选项是聚合类型(self._aggre_type)。对于特定目标节点,聚合类型决定了如何聚合不同边上的信息。

常用的聚合类型包括 meansummaxmin。一些模块可能会使用更加复杂的聚合函数,比如 lstm

上面代码里的 norm 是用于特征归一化的可调用函数。在SAGEConv论文里,归一化可以是L2归一化:

# 聚合类型:mean、max_pool、lstm、gcn
if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
     raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
 if aggregator_type == 'max_pool':
     self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
if aggregator_type == 'lstm':
      self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
if aggregator_type in ['mean', 'max_pool', 'lstm']:
      self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
      self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
      self.reset_parameters()


注册参数和子模块。在SAGEConv中,子模块根据聚合类型而有所不同。这些模块是纯PyTorch NN模块,例如 nn.Linearnn.LSTM 等。

构造函数的最后调用了 reset_parameters() 进行权重初始化。

def reset_parameters(self):
        """重新初始化可学习的参数"""
        gain = nn.init.calculate_gain('relu')
        if self._aggre_type == 'max_pool':
            nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        if self._aggre_type == 'lstm':
            self.lstm.reset_parameters()
        if self._aggre_type != 'gcn':
            nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
            nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)


2 编写DGL NN模块的forward函数


在NN模块中, forward() 函数执行了实际的消息传递和计算。与通常以张量为参数的PyTorch NN模块相比,

DGL NN模块额外增加了1个参数 :class:dgl.DGLGraphforward() 函数的内容一般可以分为3项操作:

  • 检测输入图对象是否符合规范。
  • 消息传递和聚合。
  • 聚合后,更新特征作为输出。

下文展示了SAGEConv示例中的 forward() 函数。

输入图对象的规范检测

def forward(self, graph, feat):
        with graph.local_scope():
         # 指定图类型,然后根据图类型扩展输入特征
         feat_src, feat_dst = expand_as_pair(feat, graph)


forward() 函数需要处理输入的许多极端情况,这些情况可能导致计算和消息传递中的值无效。

比如在 :class:~dgl.nn.pytorch.conv.GraphConv 等conv模块中,DGL会检查输入图中是否有入度为0的节点。

当1个节点入度为0时, mailbox 将为空,并且聚合函数的输出值全为0,

这可能会导致模型性能不佳。但是,在 :class:~dgl.nn.pytorch.conv.SAGEConv 模块中,被聚合的特征将会与节点的初始特征拼接起来,

forward() 函数的输出不会全为0。在这种情况下,无需进行此类检验。

DGL NN模块可在不同类型的图输入中重复使用,包括:同构图、异构图(:ref:guide_cn-graph-heterogeneous)和子图块(:ref:guide_cn-minibatch)。


SAGEConv的数学公式如下:

56.png


源节点特征 feat_src 和目标节点特征 feat_dst 需要根据图类型被指定。

用于指定图类型并将 feat 扩展为 feat_srcfeat_dst 的函数是 :meth:~dgl.utils.expand_as_pair

该函数的细节如下所示。

def expand_as_pair(input_, g=None):
        if isinstance(input_, tuple):
            # 二分图的情况
            return input_
        elif g is not None and g.is_block:
            # 子图块的情况
            if isinstance(input_, Mapping):
                input_dst = {
                    k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
                    for k, v in input_.items()}
            else:
                input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
            return input_, input_dst
        else:
            # 同构图的情况
            return input_, input_


对于同构图上的全图训练,源节点和目标节点相同,它们都是图中的所有节点。

在异构图的情况下,图可以分为几个二分图,每种关系对应一个。关系表示为 (src_type, edge_type, dst_dtype)

当输入特征 feat 是1个元组时,图将会被视为二分图。元组中的第1个元素为源节点特征,第2个元素为目标节点特征。

在小批次训练中,计算应用于给定的一堆目标节点所采样的子图。子图在DGL中称为区块(block)。


在区块创建的阶段,dst nodes 位于节点列表的最前面。通过索引 [0:g.number_of_dst_nodes()] 可以找到 feat_dst

确定 feat_srcfeat_dst 之后,以上3种图类型的计算方法是相同的。

消息传递和聚合

import dgl.function as fn
import torch.nn.functional as F
from dgl.utils import check_eq_shape
       if self._aggre_type == 'mean':
            graph.srcdata['h'] = feat_src
            graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
            h_neigh = graph.dstdata['neigh']
        elif self._aggre_type == 'gcn':
                check_eq_shape(feat)
                graph.srcdata['h'] = feat_src
                graph.dstdata['h'] = feat_dst
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
                # 除以入度
                degs = graph.in_degrees().to(feat_dst)
                h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
            elif self._aggre_type == 'max_pool':
                graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
                h_neigh = graph.dstdata['neigh']
            else:
                raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))
            # GraphSAGE中gcn聚合不需要fc_self
            if self._aggre_type == 'gcn':
                rst = self.fc_neigh(h_neigh)
            else:
                rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)


上面的代码执行了消息传递和聚合的计算。这部分代码会因模块而异。

聚合后,更新特征作为输出

# 激活函数
 if self.activation is not None:
        rst = self.activation(rst)
       # 归一化
   if self.norm is not None:
        rst = self.norm(rst)
        return rst


forward() 函数的最后一部分是在完成消息聚合后更新节点的特征。

常见的更新操作是根据构造函数中设置的选项来应用激活函数和进行归一化。


3 简单的图分类任务


在本教程中,我们将学习如何使用 DGL 执行图分类,这个例子的任务目标就是对下面显示的八种拓扑类型Grpah进行分类。


57.png


这里我们直接使用 DGL 中合成数据集 data.MiniGCDataset。数据集有八种不同类型的图,每个类都有相同数量的图样本

from dgl.data import MiniGCDataset
import matplotlib.pyplot as plt
import networkx as nx
# A dataset with 80 samples, each graph is
# of size [10, 20]
dataset = MiniGCDataset(80, 10, 20)
graph, label = dataset[0]
fig, ax = plt.subplots()
nx.draw(graph.to_networkx(), ax=ax)
ax.set_title('Class: {:d}'.format(label))
plt.show()

Using backend: pytorch


58.png

创建graph的批数据


59.png

image

import dgl
import torch
def collate(samples):
    # The input `samples` is a list of pairs
    #  (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels,dtype=torch.long)


构建Graph分类器


60.png

image

from dgl.nn.pytorch import GraphConv
import torch.nn as nn
import torch.nn.functional as F
class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)
    def forward(self, g):
        # Use node degree as the initial node feature. For undirected graphs, the in-degree
        # is the same as the out_degree.
        h = g.in_degrees().view(-1, 1).float()
        # Perform graph convolution and activation function.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        g.ndata['h'] = h
        # Calculate graph representation by averaging all the node representations.
        hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)

import torch.optim as optim
from torch.utils.data import DataLoader
# Create training and test sets.
trainset = MiniGCDataset(320, 10, 20)
testset = MiniGCDataset(80, 10, 20)
# Use PyTorch's DataLoader and the collate function
# defined before.
data_loader = DataLoader(trainset, batch_size=32, shuffle=True,
                         collate_fn=collate)
# Create model
model = Classifier(1, 256, trainset.num_classes)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()
epoch_losses = []
for epoch in range(80):
    epoch_loss = 0
    for iter, (bg, label) in enumerate(data_loader):
        prediction = model(bg)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (iter + 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)

Epoch 0, loss 2.0010
Epoch 1, loss 1.9744
Epoch 2, loss 1.9551
Epoch 3, loss 1.9444
Epoch 4, loss 1.9318
Epoch 5, loss 1.9170
Epoch 6, loss 1.8928
Epoch 7, loss 1.8573
Epoch 8, loss 1.8212
Epoch 9, loss 1.7715
Epoch 10, loss 1.7152
Epoch 11, loss 1.6570
Epoch 12, loss 1.5885
Epoch 13, loss 1.5308
Epoch 14, loss 1.4719
Epoch 15, loss 1.4158
Epoch 16, loss 1.3515
Epoch 17, loss 1.2963
Epoch 18, loss 1.2417
Epoch 19, loss 1.1978
Epoch 20, loss 1.1698
Epoch 21, loss 1.1086
Epoch 22, loss 1.0780
Epoch 23, loss 1.0459
Epoch 24, loss 1.0192
Epoch 25, loss 1.0017
Epoch 26, loss 1.0297
Epoch 27, loss 0.9784
Epoch 28, loss 0.9486
Epoch 29, loss 0.9327
Epoch 30, loss 0.9133
Epoch 31, loss 0.9265
Epoch 32, loss 0.9177
Epoch 33, loss 0.9303
Epoch 34, loss 0.8666
Epoch 35, loss 0.8639
Epoch 36, loss 0.8474
Epoch 37, loss 0.8858
Epoch 38, loss 0.8393
Epoch 39, loss 0.8306
Epoch 40, loss 0.8204
Epoch 41, loss 0.8057
Epoch 42, loss 0.7998
Epoch 43, loss 0.7909
Epoch 44, loss 0.7840
Epoch 45, loss 0.7807
Epoch 46, loss 0.7882
Epoch 47, loss 0.7701
Epoch 48, loss 0.7612
Epoch 49, loss 0.7563
Epoch 50, loss 0.7430
Epoch 51, loss 0.7354
Epoch 52, loss 0.7357
Epoch 53, loss 0.7326
Epoch 54, loss 0.7249
Epoch 55, loss 0.7181
Epoch 56, loss 0.7146
Epoch 57, loss 0.7306
Epoch 58, loss 0.7143
Epoch 59, loss 0.7018
Epoch 60, loss 0.7130
Epoch 61, loss 0.7003
Epoch 62, loss 0.6977
Epoch 63, loss 0.7120
Epoch 64, loss 0.6979
Epoch 65, loss 0.7370
Epoch 66, loss 0.7223
Epoch 67, loss 0.6980
Epoch 68, loss 0.6891
Epoch 69, loss 0.6715
Epoch 70, loss 0.6736
Epoch 71, loss 0.6709
Epoch 72, loss 0.6583
Epoch 73, loss 0.6717
Epoch 74, loss 0.6683
Epoch 75, loss 0.6656
Epoch 76, loss 0.6477
Epoch 77, loss 0.6414
Epoch 78, loss 0.6442
Epoch 79, loss 0.6398

plt.title('cross entropy averaged over minibatches')
plt.plot(epoch_losses)
plt.show()


61.png

model.eval()
# Convert a list of tuples to two lists
test_X, test_Y = map(list, zip(*testset))
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
probs_Y = torch.softmax(model(test_bg), 1)
sampled_Y = torch.multinomial(probs_Y, 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
print('Accuracy of sampled predictions on the test set: {:.4f}%'.format(
    (test_Y == sampled_Y.float()).sum().item() / len(test_Y) * 100))
print('Accuracy of argmax predictions on the test set: {:4f}%'.format(
    (test_Y == argmax_Y.float()).sum().item() / len(test_Y) * 100))

Accuracy of sampled predictions on the test set: 58.7500%
Accuracy of argmax predictions on the test set: 62.500000%


相关实践学习
【文生图】一键部署Stable Diffusion基于函数计算
本实验教你如何在函数计算FC上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。函数计算提供一定的免费额度供用户使用。本实验答疑钉钉群:29290019867
建立 Serverless 思维
本课程包括: Serverless 应用引擎的概念, 为开发者带来的实际价值, 以及让您了解常见的 Serverless 架构模式
相关文章
|
10天前
|
SQL 安全 前端开发
PHP与现代Web开发:构建高效的网络应用
【10月更文挑战第37天】在数字化时代,PHP作为一门强大的服务器端脚本语言,持续影响着Web开发的面貌。本文将深入探讨PHP在现代Web开发中的角色,包括其核心优势、面临的挑战以及如何利用PHP构建高效、安全的网络应用。通过具体代码示例和最佳实践的分享,旨在为开发者提供实用指南,帮助他们在不断变化的技术环境中保持竞争力。
|
14天前
|
监控 安全 网络安全
企业网络安全:构建高效的信息安全管理体系
企业网络安全:构建高效的信息安全管理体系
45 5
|
13天前
|
机器学习/深度学习 TensorFlow 算法框架/工具
利用Python和TensorFlow构建简单神经网络进行图像分类
利用Python和TensorFlow构建简单神经网络进行图像分类
39 3
|
23天前
|
数据采集 存储 机器学习/深度学习
构建高效的Python网络爬虫
【10月更文挑战第25天】本文将引导你通过Python编程语言实现一个高效网络爬虫。我们将从基础的爬虫概念出发,逐步讲解如何利用Python强大的库和框架来爬取、解析网页数据,以及存储和管理这些数据。文章旨在为初学者提供一个清晰的爬虫开发路径,同时为有经验的开发者提供一些高级技巧。
16 1
|
16天前
|
存储 安全 网络安全
|
7天前
|
SQL 安全 网络安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
【10月更文挑战第40天】在数字化时代,网络安全和信息安全已成为我们生活中不可或缺的一部分。本文将介绍网络安全漏洞、加密技术以及安全意识等方面的知识,帮助读者更好地了解网络安全的重要性,并提供一些实用的技巧和建议,以保护个人和组织的信息安全。
29 6
|
1天前
|
安全 网络安全 数据安全/隐私保护
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
在数字化时代,网络安全和信息安全已成为我们生活中不可或缺的一部分。本文将介绍网络安全漏洞、加密技术和安全意识等方面的知识,并提供一些实用的技巧和建议,帮助读者更好地保护自己的网络安全和信息安全。
|
2天前
|
安全 算法 网络协议
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
在数字时代,网络安全和信息安全已经成为了我们生活中不可或缺的一部分。本文将介绍网络安全漏洞、加密技术和安全意识等方面的内容,帮助读者更好地了解网络安全的重要性和应对措施。通过阅读本文,您将了解到网络安全的基本概念、常见的网络安全漏洞、加密技术的原理和应用以及如何提高个人和组织的网络安全意识。
|
4天前
|
存储 安全 算法
网络安全与信息安全:漏洞、加密与意识的三重防线
在数字时代的浪潮中,网络安全与信息安全成为维护数据完整性、确保个人隐私和企业资产安全的基石。本文将深入探讨网络漏洞的成因、加密技术的应用以及安全意识的培养,旨在通过技术与教育的结合,构建起一道坚固的防御体系。我们将从实际案例出发,分析常见的网络安全威胁,揭示如何通过加密算法保护数据安全,并强调提升个人和组织的安全意识在防范网络攻击中的重要性。
|
1天前
|
监控 安全 网络安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
在数字化时代,网络安全和信息安全已成为全球关注的焦点。本文将探讨网络安全漏洞、加密技术以及安全意识的重要性,并提供一些实用的建议来保护个人和组织的数据安全。我们将从网络安全漏洞的识别和防范开始,然后介绍加密技术的原理和应用,最后强调安全意识在维护网络安全中的关键作用。无论你是个人用户还是企业管理者,这篇文章都将为你提供有价值的信息和指导。

热门文章

最新文章