Spektral:使用TF2实现经典GNN的开源库

简介: Spektral:使用TF2实现经典GNN的开源库

简介


Spektral工具还发表了论文:

《Graph Neural Networks in TensorFlow and Keras with Spektral》

https://arxiv.org/abs/2006.12138

102.png


github地址:https://github.com/danielegrattarola/spektral/

在本文中,我们介绍了 Spektral,这是一个开源 Python 库,用于使用 TensorFlow 和 Keras 应用程序编程接口构建图神经网络。Spektral 实现了大量的图深度学习方法,包括消息传递和池化运算符,以及用于处理图和加载流行基准数据集的实用程序。这个库的目的是为创建图神经网络提供基本的构建块,重点是 Keras 所基于的用户友好性和快速原型设计的指导原则。因此,Spektral 适合绝对的初学者和专业的深度学习从业者。


主要网络


Spektral 实现了一些主流的图深度学习层,包括:


安装


pip安装:

pip install spektral


源码安装:

git clone https://github.com/danielegrattarola/spektral.git
cd spektral
python setup.py install  # Or 'pip install .'


Spektral实现GCN


对于TF爱好者很友好:

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.optimizers import Adam
from spektral.data.loaders import SingleLoader
from spektral.datasets.citation import Citation
from spektral.layers import GCNConv
from spektral.models.gcn import GCN
from spektral.transforms import AdjToSpTensor, LayerPreprocess
learning_rate = 1e-2
seed = 0
epochs = 200
patience = 10
data = "cora"
tf.random.set_seed(seed=seed)  # make weight initialization reproducible
# Load data
dataset = Citation(
    data, normalize_x=True, transforms=[LayerPreprocess(GCNConv), AdjToSpTensor()]
)
# We convert the binary masks to sample weights so that we can compute the
# average loss over the nodes (following original implementation by
# Kipf & Welling)
def mask_to_weights(mask):
    return mask.astype(np.float32) / np.count_nonzero(mask)
weights_tr, weights_va, weights_te = (
    mask_to_weights(mask)
    for mask in (dataset.mask_tr, dataset.mask_va, dataset.mask_te)
)
model = GCN(n_labels=dataset.n_labels, n_input_channels=dataset.n_node_features)
model.compile(
    optimizer=Adam(learning_rate),
    loss=CategoricalCrossentropy(reduction="sum"),
    weighted_metrics=["acc"],
)
# Train model
loader_tr = SingleLoader(dataset, sample_weights=weights_tr)
loader_va = SingleLoader(dataset, sample_weights=weights_va)
model.fit(
    loader_tr.load(),
    steps_per_epoch=loader_tr.steps_per_epoch,
    validation_data=loader_va.load(),
    validation_steps=loader_va.steps_per_epoch,
    epochs=epochs,
    callbacks=[EarlyStopping(patience=patience, restore_best_weights=True)],
)
# Evaluate model
print("Evaluating model.")
loader_te = SingleLoader(dataset, sample_weights=weights_te)
eval_results = model.evaluate(loader_te.load(), steps=loader_te.steps_per_epoch)
print("Done.\n" "Test loss: {}\n" "Test accuracy: {}".format(*eval_results))


相关文章
|
9月前
|
机器学习/深度学习 算法 PyTorch
python手把手搭建图像多分类神经网络-代码教程(手动搭建残差网络、mobileNET)
python手把手搭建图像多分类神经网络-代码教程(手动搭建残差网络、mobileNET)
|
机器学习/深度学习 人工智能 自然语言处理
【Pytorch神经网络理论篇】 38 Transformers:安装说明+应用结构+AutoModel类
transfomersF中包括自然语言理解和自然语言生成两大类任务,提供了先进的通用架构,其中有超2个预训练模型(细分为100多种语言的版本)。
1652 0
|
机器学习/深度学习 PyTorch 算法框架/工具
【9】一些经典CNN结构的pytorch实现
【9】一些经典CNN结构的pytorch实现
374 0
【9】一些经典CNN结构的pytorch实现
|
机器学习/深度学习 TensorFlow 算法框架/工具
TensorFlow入门(五)多层 LSTM 通俗易懂版
前言:其实之前就已经用过 LSTM 了,是在深度学习框架 keras 上直接用的,但是到现在对LSTM详细的网络结构还是不了解,心里牵挂着难受呀!今天看了 tensorflow 文档上面推荐的这篇博文,看完这后,焕然大悟,对 LSTM 的结构理解基本上没有太大问题。
5410 0
|
机器学习/深度学习 资源调度 算法
DL框架之MXNet :神经网络算法简介之MXNet 常见使用方法总结(神经网络DNN、CNN、RNN算法)之详细攻略(个人使用)
DL框架之MXNet :神经网络算法简介之MXNet 常见使用方法总结(神经网络DNN、CNN、RNN算法)之详细攻略(个人使用)
|
机器学习/深度学习 TensorFlow 算法框架/工具
TF之CNN:Tensorflow构建卷积神经网络CNN的简介、使用方法、应用之详细攻略
TF之CNN:Tensorflow构建卷积神经网络CNN的简介、使用方法、应用之详细攻略
|
机器学习/深度学习 自然语言处理 测试技术
TF之LSTM:基于Tensorflow框架采用PTB数据集建立LSTM网络的自然语言建模
TF之LSTM:基于Tensorflow框架采用PTB数据集建立LSTM网络的自然语言建模
|
机器学习/深度学习 TensorFlow API
一步一步学用Tensorflow构建卷积神经网络
本文主要和大家分享如何使用Tensorflow从头开始构建和训练卷积神经网络。这样就可以将这个知识作为一个构建块来创造有趣的深度学习应用程序了。
19296 0
|
机器学习/深度学习 自然语言处理 TensorFlow