PyG学习笔记1-INTRODUCTION BY EXAMPLE(二)

简介: PyG学习笔记1-INTRODUCTION BY EXAMPLE(二)

自定义 Dataset


尽管 PyG 已经包含许多有用的数据集,我们也可以通过继承torch_geometric.data.Dataset使用自己的数据集。提供 2 种不同的Dataset:


InMemoryDataset:使用这个Dataset会一次性把数据全部加载到内存中。

Dataset: 使用这个Dataset每次加载一个数据到内存中,比较常用。

我们需要在自定义的Dataset的初始化方法中传入数据存放的路径,然后 PyG 会在这个路径下再划分 2 个文件夹:


raw_dir: 存放原始数据的路径,一般是 csv、mat 等格式

processed_dir: 存放处理后的数据,一般是 pt 格式 ( 由我们重写process()方法实现)。


Transforms


transforms在计算机视觉领域是一种很常见的数据增强。PyG 有自己的transforms,输出是Data类型,输出也是Data类型。可以使用torch_geometric.transforms.Compose封装一系列的transforms。我们以 ShapeNet 数据集 (包含 17000 个 point clouds,每个 point 分类为 16 个类别的其中一个) 为例,我们可以使用transforms从 point clouds 生成最近邻图:

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                    pre_transform=T.KNNGraph(k=6))
# dataset[0]: Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])


还可以通过transform在一定范围内随机平移每个点,增加坐标上的扰动,做数据增强:

import torch_geometric.transforms as T
from torch_geometric.datasets import ShapeNet
dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],
                    pre_transform=T.KNNGraph(k=6),
                    transform=T.RandomTranslate(0.01))
# dataset[0]: Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])


模型训练


这里只是展示一个简单的 GCN 模型构造和训练过程,没有用到Dataset和DataLoader。


我们将使用一个简单的 GCN 层,并在 Cora 数据集上实验。有关 GCN 的更多内容,请查看**这篇博客**。


我们首先加载数据集:

from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='/tmp/Cora', name='Cora')


然后定义 2 层的 GCN:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


然后训练 200 个 epochs:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()


最后在测试集上验证了模型的准确率:

model.eval()
_, pred = model(data).max(dim=1)
correct = float (pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / data.test_mask.sum().item()
print('Accuracy: {:.4f}'.format(acc))


参考链接


PyG Documentation — pytorch_geometric 2.0.2 documentation (pytorch-geometric.readthedocs.io)

目录
相关文章
|
机器学习/深度学习 数据可视化 数据挖掘
PyTorch Geometric (PyG) 入门教程
PyTorch Geometric是PyTorch1的几何图形学深度学习扩展库。本文旨在通过介绍PyTorch Geometric(PyG)中常用的方法等内容,为新手提供一个PyG的入门教程。
PyTorch Geometric (PyG) 入门教程
|
存储 机器学习/深度学习 PyTorch
PyG学习笔记1-INTRODUCTION BY EXAMPLE(一)
PyG学习笔记1-INTRODUCTION BY EXAMPLE(一)
343 0
PyG学习笔记1-INTRODUCTION BY EXAMPLE(一)
|
机器学习/深度学习 传感器 自然语言处理
论文笔记:SpectralFormer Rethinking Hyperspectral Image Classification With Transformers_外文翻译
 高光谱(HS)图像具有近似连续的光谱信息,能够通过捕获细微的光谱差异来精确识别物质。卷积神经网络(CNNs)由于具有良好的局部上下文建模能力,在HS图像分类中是一种强有力的特征提取器。然而,由于其固有的网络骨干网的限制,CNN不能很好地挖掘和表示谱特征的序列属性。
218 0
|
机器学习/深度学习 算法 PyTorch
从零开始学Pytorch(八)之Modern CNN
从零开始学Pytorch(八)之Modern CNN
从零开始学Pytorch(八)之Modern CNN
|
机器学习/深度学习 JSON 自然语言处理
基于Vision Transformers的文档理解简介
文档理解是从pdf、图像和Word文档中提取关键信息的技术。这篇文章的目标是提供一个文档理解模型的概述。
251 0
|
算法 PyTorch 算法框架/工具
【论文阅读及复现】(2017)Densely Connected Convolutional Networks + Pytorch代码实现
- 最近的工作表明,如果卷积网络在靠近输入的层和靠近输出的层之间包含较短的连接,则它们可以更深、更准确和更有效地训练。 - 在本文中,我们接受了这一观察并介绍了密集卷积网络(DenseNet),它以前馈方式将每一层连接到其他每一层。具有 L 层的传统卷积网络有 L 个连接——每层与其后续层之间有一个连接——我们的网络有 L*(L+1) /2 个直接连接。 F 或每一层,所有前面的层的特征图被用作输入,它自己的特征图被用作所有后续层的输入。 - DenseNets 有几个引人注目的优势:它们缓解了梯度消失问题,加强了特征传播,鼓励特征重用,并大大减少了参数的数量。 - 我们在四个竞争激烈的对象
188 0
【论文阅读及复现】(2017)Densely Connected Convolutional Networks + Pytorch代码实现
|
机器学习/深度学习 运维 算法
an introduction|学习笔记
快速学习 an introduction
an introduction|学习笔记
|
机器学习/深度学习 资源调度 并行计算
李宏毅2021春季机器学习课程视频笔记1:Introduction, Colab & PyTorch Tutorials, HW1
李宏毅2021春季机器学习课程视频笔记1:Introduction, Colab & PyTorch Tutorials, HW1
李宏毅2021春季机器学习课程视频笔记1:Introduction, Colab & PyTorch Tutorials, HW1
|
机器学习/深度学习 人工智能 算法
Bag of Tricks for Efficient Text Classification 论文阅读及实战
Bag of Tricks for Efficient Text Classification 论文阅读及实战
349 0
Bag of Tricks for Efficient Text Classification 论文阅读及实战
|
数据挖掘 TensorFlow 算法框架/工具
Introduction to the Keras Tuner
The Keras Tuner is a library that helps you pick the optimal set of hyperparameters for your TensorFlow program. The process of selecting the right set of hyperparameters for your machine learning (ML) application is called hyperparameter tuning or hypertuning.
271 0

热门文章

最新文章