Spektral:使用TF2实现经典GNN的开源库

简介: Spektral:使用TF2实现经典GNN的开源库

简介


Spektral工具还发表了论文:

《Graph Neural Networks in TensorFlow and Keras with Spektral》

https://arxiv.org/abs/2006.12138

102.png


github地址:https://github.com/danielegrattarola/spektral/

在本文中,我们介绍了 Spektral,这是一个开源 Python 库,用于使用 TensorFlow 和 Keras 应用程序编程接口构建图神经网络。Spektral 实现了大量的图深度学习方法,包括消息传递和池化运算符,以及用于处理图和加载流行基准数据集的实用程序。这个库的目的是为创建图神经网络提供基本的构建块,重点是 Keras 所基于的用户友好性和快速原型设计的指导原则。因此,Spektral 适合绝对的初学者和专业的深度学习从业者。


主要网络


Spektral 实现了一些主流的图深度学习层,包括:


安装


pip安装:

pip install spektral


源码安装:

git clone https://github.com/danielegrattarola/spektral.git
cd spektral
python setup.py install  # Or 'pip install .'


Spektral实现GCN


对于TF爱好者很友好:

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.optimizers import Adam
from spektral.data.loaders import SingleLoader
from spektral.datasets.citation import Citation
from spektral.layers import GCNConv
from spektral.models.gcn import GCN
from spektral.transforms import AdjToSpTensor, LayerPreprocess
learning_rate = 1e-2
seed = 0
epochs = 200
patience = 10
data = "cora"
tf.random.set_seed(seed=seed)  # make weight initialization reproducible
# Load data
dataset = Citation(
    data, normalize_x=True, transforms=[LayerPreprocess(GCNConv), AdjToSpTensor()]
)
# We convert the binary masks to sample weights so that we can compute the
# average loss over the nodes (following original implementation by
# Kipf & Welling)
def mask_to_weights(mask):
    return mask.astype(np.float32) / np.count_nonzero(mask)
weights_tr, weights_va, weights_te = (
    mask_to_weights(mask)
    for mask in (dataset.mask_tr, dataset.mask_va, dataset.mask_te)
)
model = GCN(n_labels=dataset.n_labels, n_input_channels=dataset.n_node_features)
model.compile(
    optimizer=Adam(learning_rate),
    loss=CategoricalCrossentropy(reduction="sum"),
    weighted_metrics=["acc"],
)
# Train model
loader_tr = SingleLoader(dataset, sample_weights=weights_tr)
loader_va = SingleLoader(dataset, sample_weights=weights_va)
model.fit(
    loader_tr.load(),
    steps_per_epoch=loader_tr.steps_per_epoch,
    validation_data=loader_va.load(),
    validation_steps=loader_va.steps_per_epoch,
    epochs=epochs,
    callbacks=[EarlyStopping(patience=patience, restore_best_weights=True)],
)
# Evaluate model
print("Evaluating model.")
loader_te = SingleLoader(dataset, sample_weights=weights_te)
eval_results = model.evaluate(loader_te.load(), steps=loader_te.steps_per_epoch)
print("Done.\n" "Test loss: {}\n" "Test accuracy: {}".format(*eval_results))


相关文章
|
8月前
|
机器学习/深度学习 算法 PyTorch
python手把手搭建图像多分类神经网络-代码教程(手动搭建残差网络、mobileNET)
python手把手搭建图像多分类神经网络-代码教程(手动搭建残差网络、mobileNET)
|
机器学习/深度学习 人工智能 并行计算
5分钟掌握开源图神经网络框架DGL使用
近几年神经网络在人工智能领域的成功应用,让它备受关注和热捧。但是,它自身依然具有本质上的局限性,以往的神经网络都是限定在欧式空间内,这和大多数实际应用场景并不符合,因此,也阻碍了它在很多领域的实际落地应用。
5分钟掌握开源图神经网络框架DGL使用
|
5月前
|
机器学习/深度学习 监控 数据可视化
|
8月前
|
机器学习/深度学习 数据采集 PyTorch
PyTorch搭建卷积神经网络(ResNet-50网络)进行图像分类实战(附源码和数据集)
PyTorch搭建卷积神经网络(ResNet-50网络)进行图像分类实战(附源码和数据集)
427 1
|
8月前
|
机器学习/深度学习 编解码 TensorFlow
【Keras+计算机视觉+Tensorflow】生成对抗神经网络中DCGAN、CycleGAN网络的讲解(图文解释 超详细)
【Keras+计算机视觉+Tensorflow】生成对抗神经网络中DCGAN、CycleGAN网络的讲解(图文解释 超详细)
170 0
|
机器学习/深度学习 存储 并行计算
【Pytorch神经网络理论篇】 27 图神经网络DGL库:简介+安装+卸载+数据集+PYG库+NetWorkx库
DGL库是由纽约大学和亚马逊联手推出的图神经网络框架,支持对异构图的处理,开源相关异构图神经网络的代码,在GCMC、RGCN等业内知名的模型实现上也取得了很好的效果。
1722 0
|
机器学习/深度学习 数据采集 人工智能
【Pytorch神经网络理论篇】 40 Transformers中的词表工具Tokenizer
在Transformers库中,提供了一个通用的词表工具Tokenizer,该工具是用Rust编写的,其可以实现NLP任务中数据预处理环节的相关任务。
451 0
|
机器学习/深度学习 人工智能 自然语言处理
【Pytorch神经网络理论篇】 38 Transformers:安装说明+应用结构+AutoModel类
transfomersF中包括自然语言理解和自然语言生成两大类任务,提供了先进的通用架构,其中有超2个预训练模型(细分为100多种语言的版本)。
1630 0
|
机器学习/深度学习 人工智能 PyTorch
【Pytorch神经网络理论篇】 01 Pytorch快速上手(一)概述+张量
Pytorch是基于Torch之上的python包,在底层主要通过张量的形式进行计算,Pytorch中的张量表示为同一数据类型的多位橘子。
172 0
|
机器学习/深度学习 PyTorch 算法框架/工具
【9】一些经典CNN结构的pytorch实现
【9】一些经典CNN结构的pytorch实现
362 0
【9】一些经典CNN结构的pytorch实现

热门文章

最新文章

下一篇
开通oss服务