PyG学习笔记1-INTRODUCTION BY EXAMPLE(二)

简介: PyG学习笔记1-INTRODUCTION BY EXAMPLE(二)

自定义 Dataset


尽管 PyG 已经包含许多有用的数据集,我们也可以通过继承torch_geometric.data.Dataset使用自己的数据集。提供 2 种不同的Dataset:


InMemoryDataset:使用这个Dataset会一次性把数据全部加载到内存中。

Dataset: 使用这个Dataset每次加载一个数据到内存中,比较常用。

我们需要在自定义的Dataset的初始化方法中传入数据存放的路径,然后 PyG 会在这个路径下再划分 2 个文件夹:


raw_dir: 存放原始数据的路径,一般是 csv、mat 等格式

processed_dir: 存放处理后的数据,一般是 pt 格式 ( 由我们重写process()方法实现)。


Transforms


transforms在计算机视觉领域是一种很常见的数据增强。PyG 有自己的transforms,输出是Data类型,输出也是Data类型。可以使用torch_geometric.transforms.Compose封装一系列的transforms。我们以 ShapeNet 数据集 (包含 17000 个 point clouds,每个 point 分类为 16 个类别的其中一个) 为例,我们可以使用transforms从 point clouds 生成最近邻图:

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                    pre_transform=T.KNNGraph(k=6))
# dataset[0]: Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])


还可以通过transform在一定范围内随机平移每个点,增加坐标上的扰动,做数据增强:

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                    pre_transform=T.KNNGraph(k=6),
                    transform=T.RandomTranslate(0.01))
# dataset[0]: Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])


模型训练


这里只是展示一个简单的 GCN 模型构造和训练过程,没有用到Dataset和DataLoader。


我们将使用一个简单的 GCN 层,并在 Cora 数据集上实验。有关 GCN 的更多内容,请查看**这篇博客**。


我们首先加载数据集:

from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')


然后定义 2 层的 GCN:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


然后训练 200 个 epochs:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()


最后在测试集上验证了模型的准确率:

model.eval()
_, pred = model(data).max(dim=1)
correct = float (pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / data.test_mask.sum().item()
print('Accuracy: {:.4f}'.format(acc))


参考链接


PyG Documentation — pytorch_geometric 2.0.2 documentation (pytorch-geometric.readthedocs.io)

目录
相关文章
|
机器学习/深度学习 数据可视化 数据挖掘
PyTorch Geometric (PyG) 入门教程
PyTorch Geometric是PyTorch1的几何图形学深度学习扩展库。本文旨在通过介绍PyTorch Geometric(PyG)中常用的方法等内容,为新手提供一个PyG的入门教程。
PyTorch Geometric (PyG) 入门教程
|
存储 机器学习/深度学习 PyTorch
PyG学习笔记1-INTRODUCTION BY EXAMPLE(一)
PyG学习笔记1-INTRODUCTION BY EXAMPLE(一)
234 0
PyG学习笔记1-INTRODUCTION BY EXAMPLE(一)
|
TensorFlow 算法框架/工具
《A beginner introduction to TensorFlow (Part-1)》电子版地址
A beginner introduction to TensorFlow (Part-1)
80 0
《A beginner introduction to TensorFlow (Part-1)》电子版地址
|
机器学习/深度学习 运维 算法
an introduction|学习笔记
快速学习 an introduction
90 0
an introduction|学习笔记
|
机器学习/深度学习 人工智能 搜索推荐
【推荐系统论文精读系列】(十五)--Examples-Rules Guided Deep Neural Network for Makeup Recommendation
在本文中,我们考虑了一个全自动补妆推荐系统,并提出了一种新的例子-规则引导的深度神经网络方法。该框架由三个阶段组成。首先,将与化妆相关的面部特征进行结构化编码。其次,这些面部特征被输入到示例中——规则引导的深度神经推荐模型,该模型将Before-After图像和化妆师知识两两结合使用。
118 0
【推荐系统论文精读系列】(十五)--Examples-Rules Guided Deep Neural Network for Makeup Recommendation
|
机器学习/深度学习 负载均衡 搜索推荐
【推荐系统论文精读系列】(十六)--Locally Connected Deep Learning Framework for Industrial-scale Recommender Systems
在这项工作中,我们提出了一个局部连接的深度学习框架推荐系统,该框架将DNN的模型复杂性降低了几个数量级。我们利用Wide& Deep模型的思想进一步扩展了框架。实验表明,该方法能在较短的运行时间内取得较好的效果。
107 0
【推荐系统论文精读系列】(十六)--Locally Connected Deep Learning Framework for Industrial-scale Recommender Systems
|
机器学习/深度学习 文字识别 并行计算
CV:翻译并解读2019《A Survey of the Recent Architectures of Deep Convolutional Neural Networks》第一章~第三章(三)
CV:翻译并解读2019《A Survey of the Recent Architectures of Deep Convolutional Neural Networks》第一章~第三章
CV:翻译并解读2019《A Survey of the Recent Architectures of Deep Convolutional Neural Networks》第一章~第三章(三)
|
机器学习/深度学习 人工智能 编解码
CV:翻译并解读2019《A Survey of the Recent Architectures of Deep Convolutional Neural Networks》第一章~第三章(一)
CV:翻译并解读2019《A Survey of the Recent Architectures of Deep Convolutional Neural Networks》第一章~第三章
CV:翻译并解读2019《A Survey of the Recent Architectures of Deep Convolutional Neural Networks》第一章~第三章(一)
|
机器学习/深度学习 数据挖掘 计算机视觉
CV:翻译并解读2019《A Survey of the Recent Architectures of Deep Convolutional Neural Networks》第一章~第三章(二)
CV:翻译并解读2019《A Survey of the Recent Architectures of Deep Convolutional Neural Networks》第一章~第三章
CV:翻译并解读2019《A Survey of the Recent Architectures of Deep Convolutional Neural Networks》第一章~第三章(二)
|
数据挖掘 TensorFlow 算法框架/工具
Introduction to the Keras Tuner
The Keras Tuner is a library that helps you pick the optimal set of hyperparameters for your TensorFlow program. The process of selecting the right set of hyperparameters for your machine learning (ML) application is called hyperparameter tuning or hypertuning.
224 0