5分钟掌握开源图神经网络框架DGL使用

简介: 近几年神经网络在人工智能领域的成功应用,让它备受关注和热捧。但是,它自身依然具有本质上的局限性,以往的神经网络都是限定在欧式空间内,这和大多数实际应用场景并不符合,因此,也阻碍了它在很多领域的实际落地应用。

前言

近几年神经网络在人工智能领域的成功应用,让它备受关注和热捧。但是,它自身依然具有本质上的局限性,以往的神经网络都是限定在欧式空间内,这和大多数实际应用场景并不符合,因此,也阻碍了它在很多领域的实际落地应用。

图神经网络的出现能够有效的解决这一弊端,因此,近两年图神经网络成为了一个新的热门领域,由于它在关系表示方面更加契合现实的应用场景,使得它在搜索推荐、交通预测等领域获取了很大的成功。

但是,随之而来的问题就是,如何快速构建一个图神经网络模型成为了一个很棘手的问题。虽然,tensorflow、pytorch、mxnet在CNN、RNN领域取得了显著的成果,但是对于构建图神经网络却捉襟见肘。

于是针对这个问题,NYU和亚马逊联合开发了DGL框架,针对graph做了定制,不管是API设计还是性能优化,本文就来揭开DGL这个框架的神秘面纱吧!

DGL简介

今天给大家介绍一下开源图神经网络(Graph Neural Network,GNN)计算框架Deep Graph Library (DGL)[1]。该框架由New York University(NYU)和亚马逊公司主导开发,旨在为开发者提供一个基于现有深度学习张量计算框架可以开发GNN算法的平台。该框架0.1版本release是在2018年12月7日,当时支持的深度学习框架为pytroch和mxnet,具备基本的API,

28.jpg

给出了基本的GNN模型;今年1月24日,DGL 发布了0.4.2版本,这个新版本具备了很多 新的特性,包括支持异构图,支持后端等,这次主要给大家介绍的就是这个版本。

首先附上DGL的架构图,如图1所示,从这张图上我们可以看到:DGL最上层是应用层,目

前包含生命科学和知识图谱两类,代码路径/dgl/apps;下层有DGL定义的图、NN Modules、图上信息传递接口、图算法等;再下层是运行时,与深度学习框架对接。其中不得不提的的是图上的消息传递接口,DGL一直以来坚持以消息传递的形式来表示GNN算法,体现图上信息以send及receive的形式进行发送及接收。官网用PageRank的实现过程来说明了相关接口的使用方式https://docs.dgl.ai/tutorials/basics/3_pagerank.html,这里就不详细介绍这部分说明了,感兴趣的同学可以去看toturial相关内容。DGL工程定义了自己的图引擎,包括图的基本结构及操作等,这部分工作是利用C++实现的,通过编译为lib库的形式让python层进行调用。

DGL使用

本文主要想以Graph Covolutional Network(GCN)的是实现过程来简要介绍DGL的实现流程。GCN实现代码的路径为/dgl/examples/pytorch/gcn/gcn.py,模型训练入口为/dgl/examples/pytorch/gcn/train.py,同一目录下有readme文档,这里以pytorch的实现来介绍是因为基于tensorflow没跑通core dump了(满脸哭泣),以及不会用mxnet(学艺不精)不过这些都是细节,不影响理解核心内容。这里默认广大人民群众都很了解GCN这个网络了,所用数据集为cora,一个半监督学习任务,目标是对数据集内的文献进行分类。train.py内main函数负责控制数据加载,模型训练(下图),

29.jpg

其中train.py line 25~line 66主要进行数据读取,预处理等工作,截止到train.py line 66,利用cora的数据集构成的图还是一个network的图,train.py line 67将network图转换为DGL的图,这里体现了DGL的一个特性,支持对network图进行类型转换。train.py Line 79~94开始实例化模型,定义loss及优化器(下图)。

30.png

而在GCN模型中,核心是GraphConv的实现:dgl/python/dgl/nn/pytorch/conv/graphconv.py(下图)。

31.jpg

在GCN网络中,每个目标节点的特征(feature)更新,依靠的是将其一阶邻居的特征加和平均并与中心节点求和。计算公式为:

32.png

其中A为邻接矩阵(NxN),表征节点间连接关系,需进行归一化处理以实现求和平均;X为节点特征矩阵NxDinputW为权重(weight)矩阵DinputxDoutput。每进行一次AxW操作,图中节点特征就进行了一次更新。

从信息传递的角度理解,每个节点接收来自其一阶邻居的信息,对自身特征进行更新。注意,在GCN中如果没有权重W,这个过程与深度学习是无关的,正是因为W需要不断迭代更新,才需要深度学习框架。众所周知,每个深度学习框架都有自己的tensor类型,在计算过程中涉及的变量均为框架内tensor。以上述计算过程为例,XxWXW都是pytroch的tensor;而从DGL的角度来看,节点信息需要在DGL图上进行传递,此时要求节点信息是DGL可以操作的数据类型。DGL本身是定义了其数据类型NDArray的,对该类型数据可以进行操作。这样看来框架tensor想在DGL图上传递是存在鸿沟的,DGL通过DLPack来解决这个问题。

DLPack是Distributed (Deep) Machine Learning Community的一个开源项目,定义了一种开放的内存中tensor结构,用于深度学习框架之间共享tensor;这种数据结构mxnet原生支持,pytroch官方utils支持,tensorflow通过tfdlpack包支持。有了这种中间态存在,之前的问题就迎刃而解了。pytroch中计算出的节点信息首先转换为DLPack,然后再转换为DGL的NDArray,在图上进行传递。

信息传递是在GraphConv.py的line 119~121实现的,其中line 119~120根据节点特征得到需要传递的信息,所调用的fn.copy_src实际定义在/dgl/python/dgl/backend/pytorch/tensor.py中(下图)。因为这部分操作是嵌入在pytorch中的,所以定义了前向和后向计算的过程,以保证梯度回传顺利。

33.jpg

34.jpg

可以看到tensor.py中line 383就是把输入的pytorch tensor转换为DGL的NDArray以进行后续操作。其中核心的line 385行K.copy_reduce的定义在/dgl/src/kernel/binary_reduce.cc(上图)中。

在这个过程中可以看到DGL在C++层自定义了图操作并且是在host上进行的,也就说上面的tensor转换其实是把device(GPU)上的数据拉回host进行处理,这其实依赖pytorch支持device上数据回传的功能。这也是为什么DGL在0.4.2版本才支持tensorflow,因为tensorflow在1.15版本才提供device上回传tensor的接口。

GraphConv.py的line 121是节点信息更新的过程,这个过程与CopyReduce类似,实现的本质是AX的过程,调用的其实是cusparse提供的功能,具体实现参见/dgl/src/kernel/cuda/binary_reduce_sum.cu。这里利用cusparse本质是因为邻接矩阵A其实是一个稀疏矩阵,cusparse可以对稀疏tensor处理加速。完成节点信心更新后,GCN的核心计算过程就完成了,后续包括一些加bias和经过激活函数的操作,不再过多介绍。

以上以GCN的实现过程为例,给大家介绍了DGL是如何结合深度学习框架实现了图上的消息传递过程,值得提到的一点是,在训练过程中,graphconv.py的line 108行建立了一个局部变量,用以保存传入的图,之后的信息汇聚都在这个局部图上进行,并不影响原图,每个forward过程之后,这个局部变量都回被销毁。

本文部分内容介绍的可能不够详细与准确,期待与大家进行讨论。

Ps. DGL example里的模型,并不都是通过消息传递这种机制实现的,比如/dgl/examples/pytorch/recommendation中的PinSage就是用正常的tensor计算实现的,可能DGL也没想好怎么用send和receive来写这个模型吧。

相关文章
|
5天前
|
机器学习/深度学习 人工智能 自然语言处理
ICLR 2024 Spotlight:训练一个图神经网络即可解决图领域所有分类问题!
【2月更文挑战第17天】ICLR 2024 Spotlight:训练一个图神经网络即可解决图领域所有分类问题!
64 2
ICLR 2024 Spotlight:训练一个图神经网络即可解决图领域所有分类问题!
|
3天前
|
安全 网络协议 网络安全
OWASP Top 10 网络安全10大漏洞——A01,源码+原理+手写框架
OWASP Top 10 网络安全10大漏洞——A01,源码+原理+手写框架
|
3天前
|
关系型数据库 MySQL 网络安全
Docker部署MySQL,2024网络安全通用流行框架大全
Docker部署MySQL,2024网络安全通用流行框架大全
|
5天前
|
存储 分布式计算 监控
Hadoop【基础知识 01+02】【分布式文件系统HDFS设计原理+特点+存储原理】(部分图片来源于网络)【分布式计算框架MapReduce核心概念+编程模型+combiner&partitioner+词频统计案例解析与进阶+作业的生命周期】(图片来源于网络)
【4月更文挑战第3天】【分布式文件系统HDFS设计原理+特点+存储原理】(部分图片来源于网络)【分布式计算框架MapReduce核心概念+编程模型+combiner&partitioner+词频统计案例解析与进阶+作业的生命周期】(图片来源于网络)
142 2
|
3天前
|
机器学习/深度学习 数据挖掘 算法框架/工具
想要了解图或图神经网络?没有比看论文更好的方式,面试阿里国际站运营一般会问什么
想要了解图或图神经网络?没有比看论文更好的方式,面试阿里国际站运营一般会问什么
|
3天前
|
机器学习/深度学习 JSON PyTorch
图神经网络入门示例:使用PyTorch Geometric 进行节点分类
本文介绍了如何使用PyTorch处理同构图数据进行节点分类。首先,数据集来自Facebook Large Page-Page Network,包含22,470个页面,分为四类,具有不同大小的特征向量。为训练神经网络,需创建PyTorch Data对象,涉及读取CSV和JSON文件,处理不一致的特征向量大小并进行归一化。接着,加载边数据以构建图。通过`Data`对象创建同构图,之后数据被分为70%训练集和30%测试集。训练了两种模型:MLP和GCN。GCN在测试集上实现了80%的准确率,优于MLP的46%,展示了利用图信息的优势。
10 1
|
5天前
|
缓存 网络安全 Android开发
|
5天前
|
存储 网络协议 Linux
RTnet – 灵活的硬实时网络框架
本文介绍了开源项目 RTnet。RTnet 为以太网和其他传输媒体上的硬实时通信提供了一个可定制和可扩展的框架。 本文描述了 RTnet 的架构、核心组件和协议。
23 0
RTnet – 灵活的硬实时网络框架
|
5天前
|
JSON Kubernetes 网络架构
Kubernetes CNI 网络模型及常见开源组件
【4月更文挑战第13天】目前主流的容器网络模型是CoreOS 公司推出的 Container Network Interface(CNI)模型
|
5天前
|
网络协议 Java API
Python网络编程基础(Socket编程)Twisted框架简介
【4月更文挑战第12天】在网络编程的实践中,除了使用基本的Socket API之外,还有许多高级的网络编程库可以帮助我们更高效地构建复杂和健壮的网络应用。这些库通常提供了异步IO、事件驱动、协议实现等高级功能,使得开发者能够专注于业务逻辑的实现,而不用过多关注底层的网络细节。

热门文章

最新文章