基于GCN和DGL实现的图上 node 分类, 值得一看!!!

简介: 基于GCN和DGL实现的图上 node 分类, 值得一看!!!

基于GCN和DGL实现的图上 node 分类, 值得一看!!!


书接上文,我们在 GraphSage与DGL实现同构图 Link 预测,通俗易懂好文强推 这篇文章中,初步讲解了 图上机器任务 的大致分类以及图上 有无监督任务loss 的区别,并基于作者自己的理解分析了 图上消息传播 的过程。在上文中,我们使用 GraphSage与DGL技术实现了无监督情境下的 链接预测 ,本文我们将继续使用DGL和GCN网络来实现图上的有监督学习之一的 节点分类 ,下面让我我们开始吧~

(1) 图上Node分类任务基础

从上篇文章文章中我们知道:在图上去预测两个节点之间的关系(或边)是否存在的任务,是 典型的链接(关系)预测 任务。在图上 基于图上结构之间的关系对某个节点进行定性性质的分析 ,例如判断一个用户是否是异常用户等,这种任务则是典型的 节点(Node)分类 任务,对每个节点我们是需要 打上label数据 的,二分类就是 0或1 。


从这里我们也可以 推断 出: 和 传统机器学习任务 类似,如果要进行回归任务的话,我们只需要对节点的label数据赋予需要回归出的浮点数即可。


这里我们需要明确的一点就是:我们无论是进行图上节点的 分类还是回归 任务,我们在最后一层DNN的前一层得到logit的时候,该 logit的数据其实携带了该节点周围的邻居节点的关系信息 ,我们是融合了 该节点 以及其 邻居节点 的信息(通常以Embeding的形式存在) 来 对当前节点 进行的 定性形式 的判断。


并且这里所说的节点以及邻居节点的信息,是指节点的 各个属性 ,比较简单的任务可能就是对每个节点就是一个综合的embeding ,而对于复杂的任务,每个节点可能有很多个向量,则我们就要对 各个向量 分别写 适合该属性数据特性 的聚合处理逻辑以及后面接入全链接DNN,也可以把得到各个属性的embeding用一个简单的网络 融合 之后在进行 信息传播 均是可以的。


这里和传统机器学习的不同就是: 传统机器学习做决断用到的信息 都是 独立同分布 的,而这里做决断用到的部分信息是 依赖图的空间结构上处于邻居位置的信息 的,非独立同分布 的。


历史图相关文章链接如下:


(1)一文揭开图机器学习的面纱,你确定不来看看吗


(2)graphSage还是HAN ?吐血力作综述Graph Embeding 经典好文


(3) 看这里,使用docker部署图深度学习框架GraphLearn使用说明


(4) GraphSage与DGL实现同构图 Link 预测,通俗易懂好文强推


本文这里主要讲解的是 基于GCN和DGL实现的图上 Node 分类任务 ,对图上消息传播与各种任务不太熟悉的同学,可以先去看看上面列出的几篇历史文章,其中对图的知识有了大概的讲解。


下面让我们结合代码开始今天的学习吧~

(2) 采样与算子适用性说明

从前一篇文章 GraphSage与DGL实现同构图 Link 预测,通俗易懂好文强推 中,我们可以知道:虽然我们使用了dgl 官方实现的Graph Sage 算子,但是Sage算子的提出最初主要是为了解决GCN任务理论上每次均把所有节点以及他们的邻居节点均 load进内存 导致内存放不下或则 训练速度缓慢 的问题,以及对 没见过的节点进行预测 的问题。


对于上面第二点,我们要引入一个概念就是 直推式学习 (transductive learning) 和 归纳式学习 (Inductive learning) 。传统的GCN就是属于直推式学习, 训练 节点embedding 的时候要看到全图的节点, 其根本原因是因为它使用了 拉普拉斯矩阵 。而 Sage算子 ,则是归纳式学习,因为它的基本逻辑是可以 用“你“的邻居信息归纳出没见过的”你“的信息 ,可以处理未见过你但知道你邻居这种类似的问题。

例如:你在北京的朋友都是程序员或则从事互联网的职业,那你在很大程度上也是程序员或从事互联网的职业…


在前一篇文章中,我们虽然使用了 GraphSage算子,但是我们并没有使用Sage算子与底层原理相契合的 邻居采样 的方式。因为链接预测需要对构造没出现过的边作为负样本,文章中选择了对每个种子节点进行 全局负采样 ,这里的采样就是 必须 的。 而对于节点分类这种任务,属于有监督学习任务,节点的正负由外界赋予的label来确定,所以节点采样是 非必须 的,我们 进行采样必然是为了解决某一类问题 。


Sage算子 具有 采样与 聚合 两种的特性,我们分别在采样和聚合的阶段去利用这种特性,当然这两种都用也是可以的。我们如果不采样,则只是利用来它多种的聚合方式的特性,去聚合所有的邻居节点。当然我们也可以用采样的方式,对每个节点去采样一定数目的邻居,参与Sage算子的聚合阶段。


当然,除了对每个节点的 邻居进行采样 ,我们通常也会对大数据集合进行 分batch训练 ,这两种方式都可以 减小数据对内存 的压力。在DGL中,我们可以采用

dgl.dataloading.NodeDataLoader

接口进行针对 节点 的采样操作,而对 边采样 则可以使用 dgl.dataloading.EdgeDataLoader 接口,可以同时对上述说的两个方面进行工程上的实现,**dgl是我心目中处理图机器学习的 yyds ** 这两个采样算子的接口说明,作者将在后续的文章中将继续写一篇介绍比较复杂的任务的文章,在其中将对算子的使用进行说明~


这里主要拿Sage算子来阐述,主要是想说明:针对不同的任务,例如节点和边的任务,采用不同的算子,用这个算子的什么特性我们都是可以根据自己的数据流程灵活定制,我们在拿到代码的时候,可以去深入思考一下,看使用了什么特性以及怎么搭配能更好的处理我们的任务。这里不光针对Sage算子的采样与聚合,以及后面的同异构图的Attention也是如此。

(3) 代码时光

注意:这里的代码是 基于 gcn 和 dgl 实现的痛构图上有监督的节点分类任务 .


开篇吼一嗓子 , talk is cheap , show me the code !!! 下面,就让我们开始coding 吧~

(3.1) 导包

@ 欢迎关注微信公众号:算法全栈之路

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn
from dgl import AddSelfLoop
from dgl.data import CiteseerGraphDataset, CoraGraphDataset, PubmedGraphDataset

老规矩,先导包。跑起来这个任务,仅仅需要这些包就可以。


(3.2) 模型结构定义

这里我们采用 dgl 官方实现的 graphConv 算子进行邻居节点信息的聚合,不进行邻居节点的采样。

@ 欢迎关注微信公众号:算法全栈之路
class GCN(nn.Module):
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # two-layer GCN
        self.layers.append(
            dglnn.GraphConv(in_size, hid_size, activation=F.relu)
        )
        self.layers.append(dglnn.GraphConv(hid_size, out_size))
        self.dropout = nn.Dropout(0.5)
    def forward(self, g, features):
        h = features
        for i, layer in enumerate(self.layers):
            if i != 0:
                h = self.dropout(h)
            h = layer(g, h)
        return h


我们可以看到:这里的网络结构选择的是gcn 。在上面的网络结构中,nn.ModuleList中放有2层的GraphConv卷积层,并在其中加入了dropout层。图卷基层之间也是可以加入dropout层的,和传统的深度学习DNN无任何区别。


为了 加深理解 ,我们可以重点关注下gcn模型的初始化参数以及输入输出参数。可以看到:初始化参数包括了模型的输入参数,这里就是节点初始embeding的维度,隐藏维度以及输出维度,h的维度和out_size相同 。在本文,h维度也是和 logit维度相同,等于输出类别数。


(3.3) 模型评估代码

@ 欢迎关注微信公众号:算法全栈之路
# 模型评估,返回准确率 
def evaluate(g, features, labels, mask, model):
    model.eval()
    with torch.no_grad():
        logits = model(g, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)


这里的评估指标是 准确率,衡量的正负样本预测的准确与否。函数的输入是 G图, feature 是节点的原始属性feature,经过gcn后得到logit, 和标签进行比较。


我们可以看到其中用到了 model.eval() 代码。 它的作用是:在评估过程中不启用 Batch Normalization 和 Dropout,一般在测试的代码之前添加,使得在测试时保证BN层能够用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。

对于Dropout则是用到了所有网络连接,即不进行随机舍弃神经元。


最后这里,我们可以看到并没有进行交叉熵计算的代码,是因为:加上了交叉熵并不会影响计算出的max logit 的 indeces ,也就不会影响模型评估的准确性了。


(3.4) 训练模块

@ 欢迎关注微信公众号:算法全栈之路
def train(g, features, labels, masks, model):
    # define train/val samples, loss function and optimizer
    train_mask = masks[0]
    val_mask = masks[1]
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    # training loop
    for epoch in range(20):
        model.train()
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        acc = evaluate(g, features, labels, val_mask, model)
        print(
            "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} ".format(
                epoch, loss.item(), acc
            )
        )

注意: 这里的损失函数采用的是交叉熵损失,而其实 CrossEntropyLoss相当于softmax + log + nllloss , 这也就和我们上面的代码是相互映证的。


优化器选择的是 adam, 当我们不知道选择什么优化器好的时候,无脑选择adam 总是能给我们一个不错的结果。


中间train 模块的输入输出参数,可以在下面的主函数中看到,这里就不再多做说明了。

model.train()和 model.eval() 同理,只不过作用相反。


(3.5) 模型输入参数等主函数

我们这里的测试数据,直接采用了dgl官方提供的数据集, 这里提供了三个数据集。


闲言少叙,看代码吧~

@ 欢迎关注微信公众号:算法全栈之路
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        default="cora",
        help="Dataset name ('cora', 'citeseer', 'pubmed').",
    )
    args = parser.parse_args()
    print(f"Training with DGL built-in GraphConv module.")
    # load and preprocess dataset
    # 添加自环
    transform = (
        AddSelfLoop()
    )  # by default, it will first remove self-loops to prevent duplication
    if args.dataset == "cora":
        data = CoraGraphDataset(transform=transform)
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset(transform=transform)
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset(transform=transform)
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))
    g = data[0]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    g = g.int().to(device)
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    masks = g.ndata["train_mask"], g.ndata["val_mask"], g.ndata["test_mask"]
    # create GCN model
    in_size = features.shape[1]
    out_size = data.num_classes
    model = GCN(in_size, 16, out_size).to(device)
    # model training
    print("Training...")
    train(g, features, labels, masks, model)
    # test the model
    print("Testing...")
    acc = evaluate(g, features, labels, masks[2], model)
    print("Test accuracy {:.4f}".format(acc))


我们可以看到: 上述的代码中,有给 图添加自环 的过程 AddSelfLoop ,是因为添加自环可以 有效缓解图的稀疏性 ,能够提高模型的训练效果。


中间的 g = g.int().to(device) 是把图数据转入到device 中,如果训练model的机器有gpu的话,是可以把 数据复制到gpu的显存 里去的。


把上面的代码复制到一个python文件中是可以完美运行的,我这里的 dgl版本选择的是0.9 。代码本身是非常通俗易懂的,望文可以知其意 ,我就不在过多赘述了哈。如果有任何疑问,欢迎关注公众号讨论~


最后在强调一点: 这里的代码和上一篇文章中的代码用了一样的处理方式:节点 feature 不随着网络的更新而更新,如果要随着网络更新,可以去看 GraphSage与DGL实现同构图 Link 预测,通俗易懂好文强推 文章最后介绍的处理方式,或则去关注作者的后面一篇文章,会选择让特征随着网络更新的方式进行更新的代码实现。


宅男民工码字不易,你的关注是我持续输出的最大动力。


接下来作者会继续分享学习与工作中一些有用的、有意思的内容,点点手指头支持一下吧~


了解更多更全内容 : 算法全栈之路


相关文章
|
14天前
|
缓存 JavaScript 前端开发
Node.js模块化的基本概念和分类及使用方法
Node.js模块化的基本概念和分类及使用方法
18 0
|
JavaScript 中间件
Node.js学习笔记----中间件的分类
Node.js学习笔记----中间件的分类
|
机器学习/深度学习 编解码 算法
重磅好文透彻理解,异构图上 Node 分类理论与DGL源码实战
重磅好文透彻理解,异构图上 Node 分类理论与DGL源码实战
重磅好文透彻理解,异构图上 Node 分类理论与DGL源码实战
|
JSON JavaScript 数据格式
node.js 中模块的分类|学习笔记
快速学习 node.js 中模块的分类
127 0
NodeJs--模块分类
一,模块的基本分类 二,模块的流程 三,简单示例 1,模块student: function add(student){ console.log('add student:'+student); } exports.
885 0
|
2月前
|
JavaScript
NodeJs的安装
文章介绍了Node.js的安装步骤和如何创建第一个Node.js应用。包括从官网下载安装包、安装过程、验证安装是否成功,以及使用Node.js监听端口构建简单服务器的示例代码。
NodeJs的安装
|
24天前
|
JavaScript 开发工具 git
已安装nodejs但是安装hexo报错
已安装nodejs但是安装hexo报错
21 2
|
2月前
|
存储 JavaScript 前端开发
Node 版本控制工具 NVM 的安装和使用(Windows)
本文介绍了NVM(Node Version Manager)的Windows版本——NVM for Windows的安装和使用方法,包括如何安装Node.js的特定版本、列出已安装版本、切换使用不同版本的Node.js,以及其他常用命令,以实现在Windows系统上对Node.js版本的便捷管理。
Node 版本控制工具 NVM 的安装和使用(Windows)
|
18天前
|
Web App开发 JavaScript 前端开发
JavaWeb 22.Node.js_简介和安装
JavaWeb 22.Node.js_简介和安装