基于Pytorch Gemotric在昇腾上实现GraphSage图神经网络

简介: 本文详细介绍了如何在昇腾平台上使用PyTorch实现GraphSage算法,在CiteSeer数据集上进行图神经网络的分类训练。内容涵盖GraphSage的创新点、算法原理、网络架构及实战代码分析,通过采样和聚合方法高效处理大规模图数据。实验结果显示,模型在CiteSeer数据集上的分类准确率达到66.5%。

本文主要介绍了如何在昇腾上,使用pytorch对经典的图神经网络GraphSage在论文引用CiteSeer数据集上进行分类训练的实战讲解。


内容包括GraphSage创新点分析、GraphSage算法原理、GraphSage网络架构剖析与GraphSage网络模型代码实战分析等等。


的目录结构安排如下所示:


- GraphSage创新点分析

- GraphSage算法原理

- GraphSage网络架构剖析

- GraphSage网络用于CiteSeer数据集分类实战


## GraphSage创新点分析


- 本文提出了一种归纳式学习模型,可以得到新点/新图的表征。

- 模型通过学习一组函数来得到点的表征。之前的随机游走方式则是先随机初始化点的表征,然后通过模型的训练更新点的表征来获取点的表征,这样无法进行归纳式学习。

- 采样并汇聚点的邻居特征与节点的特征拼接得到点的特征。

- GraphSAGE算法在直推式和归纳式学习均达到最优效果。


## GraphSage算法原理


GCN网络每次学习都需要将整个图送入显存/内存中,资源消耗巨大。另外使用整个图结构进行学习,导致了GCN的学习的固化,图中一旦新增节点,整个图的学习都需要重新进行。这两点对于大数据集和项目实际落地来说,是巨大的阻碍。


我们知道,GCN网络的每一次卷积的过程,每个节点都是只与自己周围的信息节点进行交互,对于单层的GCN网络而言,每个节点只能够接触到自己一跳以内的节点信息,两层的话则是两跳,通过一跳邻居间接传播过来。

be997958-de94-424c-a8fc-3fd3b0fd1e50.png


假设需要图中一个点的卷积的结果,并且只使用了一层GCN,那么实际上我只需要这个节点和它的全部邻居节点,图中的其他节点是没有意义的;如果使用了两层GCN,那么我只需要这个节点和它的两跳以内的全部邻居,以此类推。因为更远节点的信息根本无法到达这个节点。


此外,当图中的某些节点是有几百上千个邻居的超级节点,对于这种节点,哪怕只进行一跳邻居采样,仍然会导致计算困难。GraphSage网络采用抽取出一部分待训练节点和它们的N跳内邻居可以完成这些节点的训练,与此同时对于超级节点的情况,提出了采样的思路,即只对目标节点的邻居进行一定数量的采样,然后通过这次被采样出来的节点和目标节点进行计算,从而逼近完全聚合的效果。


GraphSAGE模型,是一种在图上的通用的归纳式的框架,利用节点特征信息(例如文本属性)来高效地为训练阶段未见节点生成embedding。该模型学习的不是节点的embedding向量,而是学习一种聚合方式,即如何通过从一个节点的局部邻居采样并聚合顶点特征,得到节点最终embedding表征。 当学习到适合的聚合函数后,可以迅速应用到未见过的图上,得到未见过的节点embedding。


因此**采样**与**聚合**是GraphSage网络的两大主要工作,通过随机采样的方式从整张图中抽出一张子图近似替换原始图,然后在该子图上进行聚合计算提取信息特征。


/*** 算法流程解读 ***/


第一个 for循环针对层数进行遍历,表示进行多少层的GraphSAGE, 第二个for循环用于遍历Graph中的所有节点, 针对每个节点, 对邻居进行采样得到$N(v)$邻居节点集合,然后遍历该集合使用$AGGREGATE_k(.)$对邻居节点信息进行聚合得到$\mathbf{h}_{N(v)}^{k}$,最后,通过对邻居节点的及目标节点上衣节点信息$\mathbf{h}_{u}^{k - 1}$进行拼接, 经过非线性变换后赋得到v节点在当前k层的节点权重值。

5cc2dc49-a15e-478e-a542-ce36f416366b.PNG

## GraphSage网络架构剖析


3c1fdfa7-b45a-45dc-bc48-d324bf4e21e0.PNG

GraphSAGE不是试图学习一个图上所有node的embedding,而是学习一个为每个node产生embedding的映射。 GraphSage框架中包含两个很重要的操作:Sample采样和Aggregate聚合。这也是其名字GraphSage(Graph SAmple and aggreGatE)的由来。GraphSAGE 主要分两步:采样、聚合。GraphSAGE的采样方式是邻居采样,邻居采样的意思是在某个节点的邻居节点中选择几个节点作为原节点的一阶邻居,之后对在新采样的节点的邻居中继续选择节点作为原节点的二阶节点,以此类推。


文中不是对每个顶点都训练一个单独的embeddding向量,而是训练了一组aggregator functions,这些函数学习如何从一个顶点的局部邻居聚合特征信息。每个聚合函数从一个顶点的不同的hops或者说不同的搜索深度聚合信息。测试或是推断的时候,使用训练好的系统,通过学习到的聚合函数来对完全未见过的顶点生成embedding。


上图包含下述三个步骤:


- 对图中每个顶点邻居顶点进行采样,因为每个节点的度是不一致的,为了计算高效, 为每个节点采样固定数量的邻居

- 根据聚合函数聚合邻居顶点蕴含的信息

- 得到图中各顶点的向量表示供下游任务使用


## GraphSage网络用于CiteSeer数据集分类实战


导入torch相关库,'functional'集成了一些非线性算子函数,例如Relu等。

import torch
import torch.nn.functional as F

由于torch_geometric中集成了单层的SAGEConv模块,这里直接进行导入,若有兴趣可以自行实现该类,注意输入与输出对齐即可。此外,数据集用的是CiteSeer,该数据集也直接集成在Planetoid模块中,这里也需要将其import进来。

# 导入GraphSAGE层
from torch_geometric.nn import SAGEConv
from torch_geometric.datasets import Planetoid

本实验需要跑在Npu上,因此将Npu相关库导入,'transfer_to_npu'可以使模型快速的迁移到Npu上进行运行。

#导入Npu相关库
import torch_npu
from torch_npu.contrib import transfer_to_npu

## CiteSeer数据集介绍

12119f91-229a-4a65-8426-9199d1011196.png


CiteSeer是一个学术论文数据集,主要涉及计算机科学领域。它由NEC研究院开发,基于自动引文索引(ACI)机制,提供了一种通过引文链接来检索文献的方式。


CiteSeer数据集由学术论文组成,每篇论文被视为一个节点,引用关系被视为边,总共包含6个类别,对应图中6种不同颜色的节点。


包含**3327篇论文也就是3327个节点**,每个**节点有一个3703维的二进制特征向量**,用来表示论文的内容,其中特征向量采用词袋模型(Bag of Words)表示,即每个特征维度对应一个词汇表中的词,值为1表示该词在论文中出现,为0表示未出现。在这些论文中共组成**4732条边**表示论文与论文的引用关系。6个类别分别是Agents、Artificial、Intelligence、Database、Information Retrieval、Machine Learning与Human-Computer Interaction。


加载数据集,root为下载路径的保存默认保存位置,若下载不下来可手动下载后保存在指定路径即可。

print("===== begin Download Dadasat=====\n")
dataset = Planetoid(root='/home/pengyongrong/workspace/data', name='CiteSeer')
print("===== Download Dadasat finished=====\n")

print("dataset num_features  is:  ", dataset.num_features)
print("dataset.num_classes is:  ", dataset.num_classes)

print("dataset.edge_index is:  ", dataset.edge_index)

print("train data is:   ", dataset.data)
print("dataset0 is:  ", dataset[0])

print("train data mask is:   ", dataset.train_mask, "num train is: ", (dataset.train_mask ==True).sum().item())
print("val data mask is:   ",dataset.val_mask, "num val is: ", (dataset.val_mask ==True).sum().item())
print("test data mask is:   ",dataset.test_mask,  "num test is: ", (dataset.test_mask ==True).sum().item())

===== begin Download Dadasat=====


===== Download Dadasat finished=====


dataset num_features  is:   3703

dataset.num_classes is:   6

dataset.edge_index is:   tensor([[ 628,  158,  486,  ..., 2820, 1643,   33],

       [   0,    1,    1,  ..., 3324, 3325, 3326]])

train data is:    Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])

dataset0 is:   Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])

train data mask is:    tensor([ True,  True,  True,  ..., False, False, False]) num train is:  120

val data mask is:    tensor([False, False, False,  ..., False, False, False]) num val is:  500

test data mask is:    tensor([False, False, False,  ...,  True,  True,  True]) num test is:  1000


/home/pengyongrong/miniconda3/envs/AscendCExperiments/lib/python3.9/site-packages/torch_geometric/data/in_memory_dataset.py:300: UserWarning: It is not recommended to directly access the internal storage format `data` of an 'InMemoryDataset'. If you are absolutely certain what you are doing, access the internal storage via `InMemoryDataset._data` instead to suppress this warning. Alternatively, you can access stacked individual attributes of every graph via `dataset.{attr_name}`.

 warnings.warn(msg)


搭建两层GraphSAGE网络,其中sage1与sage2分别表示第一层与第二层,这里如果有需要可以搭建多层的GraphSage,注意保持输入链接出大小相互匹配即可。

class GraphSAGE_NET(torch.nn.Module):

    def __init__(self, feature, hidden, classes):
        super(GraphSAGE_NET, self).__init__()
        self.sage1 = SAGEConv(feature, hidden)
        self.sage2 = SAGEConv(hidden, classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.sage1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.sage2(x, edge_index)

        return F.log_softmax(x, dim=1)

定义设备跑在Npu上,这里如果需要替换成Gpu或Cpu,则替换成'cuda'或'cpu'即可。

device = 'npu'

定义GraphSAGE网络,中间隐藏层节点个数定义为16,'dataset.num_classes'为先前数据集中总的类别数,这里是7类。'to()'的作用是将该加载到指定模型设备上。优化器用的是'optim'中的'Adam'。

model = GraphSAGE_NET(dataset.num_node_features, 16, dataset.num_classes).to(device) 
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

开始训练模型,指定训练次数200次,训练后采用极大似然用作损失函数计算损失,然后进行反向传播更新模型的参数,训练完成后,用验证集中的数据对模型效果进行验证,最后打印模型的准确率为0.665。

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()


model.eval()
_, pred = model(data).max(dim=1)
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / int(data.test_mask.sum())
print('GraphSAGE Accuracy: {:.4f}'.format(acc))

GraphSAGE Accuracy: 0.6650

**显存使用情况:** 整个训练过程的存使用情况可以通过"npu-smi info"命令在终端查看,因此本文实验只用到了单个npu卡(也就是chip 0),存占用约943M,对存、精度或性能优化有兴趣的可以自行尝试进行优化。

57dd1185-474f-4b9a-be1e-ed610995da01.png


# Reference


[1] Hamilton, William L ,  R. Ying , and  J. Leskovec . "Inductive Representation Learning on Large Graphs." (2017).

相关文章
|
7天前
|
存储 运维 安全
云上金融量化策略回测方案与最佳实践
2024年11月29日,阿里云在上海举办金融量化策略回测Workshop,汇聚多位行业专家,围绕量化投资的最佳实践、数据隐私安全、量化策略回测方案等议题进行深入探讨。活动特别设计了动手实践环节,帮助参会者亲身体验阿里云产品功能,涵盖EHPC量化回测和Argo Workflows量化回测两大主题,旨在提升量化投研效率与安全性。
云上金融量化策略回测方案与最佳实践
|
9天前
|
人工智能 自然语言处理 前端开发
从0开始打造一款APP:前端+搭建本机服务,定制暖冬卫衣先到先得
通义灵码携手科技博主@玺哥超carry 打造全网第一个完整的、面向普通人的自然语言编程教程。完全使用 AI,再配合简单易懂的方法,只要你会打字,就能真正做出一个完整的应用。
8494 20
|
13天前
|
Cloud Native Apache 流计算
资料合集|Flink Forward Asia 2024 上海站
Apache Flink 年度技术盛会聚焦“回顾过去,展望未来”,涵盖流式湖仓、流批一体、Data+AI 等八大核心议题,近百家厂商参与,深入探讨前沿技术发展。小松鼠为大家整理了 FFA 2024 演讲 PPT ,可在线阅读和下载。
4566 11
资料合集|Flink Forward Asia 2024 上海站
|
13天前
|
自然语言处理 数据可视化 API
Qwen系列模型+GraphRAG/LightRAG/Kotaemon从0开始构建中医方剂大模型知识图谱问答
本文详细记录了作者在短时间内尝试构建中医药知识图谱的过程,涵盖了GraphRAG、LightRAG和Kotaemon三种图RAG架构的对比与应用。通过实际操作,作者不仅展示了如何利用这些工具构建知识图谱,还指出了每种工具的优势和局限性。尽管初步构建的知识图谱在数据处理、实体识别和关系抽取等方面存在不足,但为后续的优化和改进提供了宝贵的经验和方向。此外,文章强调了知识图谱构建不仅仅是技术问题,还需要深入整合领域知识和满足用户需求,体现了跨学科合作的重要性。
|
21天前
|
人工智能 自动驾驶 大数据
预告 | 阿里云邀您参加2024中国生成式AI大会上海站,马上报名
大会以“智能跃进 创造无限”为主题,设置主会场峰会、分会场研讨会及展览区,聚焦大模型、AI Infra等热点议题。阿里云智算集群产品解决方案负责人丛培岩将出席并发表《高性能智算集群设计思考与实践》主题演讲。观众报名现已开放。
|
9天前
|
人工智能 容器
三句话开发一个刮刮乐小游戏!暖ta一整个冬天!
本文介绍了如何利用千问开发一款情侣刮刮乐小游戏,通过三步简单指令实现从单个功能到整体框架,再到多端优化的过程,旨在为生活增添乐趣,促进情感交流。在线体验地址已提供,鼓励读者动手尝试,探索编程与AI结合的无限可能。
三句话开发一个刮刮乐小游戏!暖ta一整个冬天!
|
1月前
|
存储 人工智能 弹性计算
阿里云弹性计算_加速计算专场精华概览 | 2024云栖大会回顾
2024年9月19-21日,2024云栖大会在杭州云栖小镇举行,阿里云智能集团资深技术专家、异构计算产品技术负责人王超等多位产品、技术专家,共同带来了题为《AI Infra的前沿技术与应用实践》的专场session。本次专场重点介绍了阿里云AI Infra 产品架构与技术能力,及用户如何使用阿里云灵骏产品进行AI大模型开发、训练和应用。围绕当下大模型训练和推理的技术难点,专家们分享了如何在阿里云上实现稳定、高效、经济的大模型训练,并通过多个客户案例展示了云上大模型训练的显著优势。
104589 10
|
8天前
|
消息中间件 人工智能 运维
12月更文特别场——寻找用云高手,分享云&AI实践
我们寻找你,用云高手,欢迎分享你的真知灼见!
728 45
|
6天前
|
弹性计算 运维 监控
阿里云云服务诊断工具:合作伙伴架构师的深度洞察与优化建议
作为阿里云的合作伙伴架构师,我深入体验了其云服务诊断工具,该工具通过实时监控与历史趋势分析,自动化检查并提供详细的诊断报告,极大提升了运维效率和系统稳定性,特别在处理ECS实例资源不可用等问题时表现突出。此外,它支持预防性维护,帮助识别潜在问题,减少业务中断。尽管如此,仍建议增强诊断效能、扩大云产品覆盖范围、提供自定义诊断选项、加强教育与培训资源、集成第三方工具,以进一步提升用户体验。
640 243
|
3天前
|
弹性计算 运维 监控
云服务测评 | 基于云服务诊断全方位监管云产品
本文介绍了阿里云的云服务诊断功能,包括健康状态和诊断两大核心功能。作者通过个人账号体验了该服务,指出其在监控云资源状态和快速排查异常方面的优势,同时也提出了一些改进建议,如增加告警配置入口和扩大诊断范围等。