使用Pytorch Geometric 进行链接预测代码示例

简介: 该代码示例使用PyTorch和`torch_geometric`库实现了一个简单的图卷积网络(GCN)模型,处理Cora数据集。模型包含两层GCNConv,每层后跟ReLU激活和dropout。模型在训练集上进行200轮训练,使用Adam优化器和交叉熵损失函数。最后,计算并打印测试集的准确性。
import torch
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid

# 加载数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

# 定义图卷积网络模型
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = torch.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return torch.log_softmax(x, dim=1)

# 初始化模型、优化器和损失函数
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()

# 训练模型
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# 测试模型
model.eval()
_, pred = model(data).max(dim=1)
correct = float(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / data.test_mask.sum().item()
print('Accuracy: {:.4f}'.format(acc))
目录
相关文章
|
8月前
|
机器学习/深度学习 JavaScript PyTorch
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
生成对抗网络(GAN)的训练效果高度依赖于损失函数的选择。本文介绍了经典GAN损失函数理论,并用PyTorch实现多种变体,包括原始GAN、LS-GAN、WGAN及WGAN-GP等。通过分析其原理与优劣,如LS-GAN提升训练稳定性、WGAN-GP改善图像质量,展示了不同场景下损失函数的设计思路。代码实现覆盖生成器与判别器的核心逻辑,为实际应用提供了重要参考。未来可探索组合优化与自适应设计以提升性能。
685 7
9个主流GAN损失函数的数学原理和Pytorch代码实现:从经典模型到现代变体
|
3月前
|
PyTorch 算法框架/工具 异构计算
PyTorch 2.0性能优化实战:4种常见代码错误严重拖慢模型
我们将深入探讨图中断(graph breaks)和多图问题对性能的负面影响,并分析PyTorch模型开发中应当避免的常见错误模式。
252 9
|
存储 物联网 PyTorch
基于PyTorch的大语言模型微调指南:Torchtune完整教程与代码示例
**Torchtune**是由PyTorch团队开发的一个专门用于LLM微调的库。它旨在简化LLM的微调流程,提供了一系列高级API和预置的最佳实践
634 59
基于PyTorch的大语言模型微调指南:Torchtune完整教程与代码示例
|
4月前
|
机器学习/深度学习 数据可视化 PyTorch
Flow Matching生成模型:从理论基础到Pytorch代码实现
本文将系统阐述Flow Matching的完整实现过程,包括数学理论推导、模型架构设计、训练流程构建以及速度场学习等关键组件。通过本文的学习,读者将掌握Flow Matching的核心原理,获得一个完整的PyTorch实现,并对生成模型在噪声调度和分数函数之外的发展方向有更深入的理解。
1886 0
Flow Matching生成模型:从理论基础到Pytorch代码实现
|
5月前
|
机器学习/深度学习 PyTorch 算法框架/工具
提升模型泛化能力:PyTorch的L1、L2、ElasticNet正则化技术深度解析与代码实现
本文将深入探讨L1、L2和ElasticNet正则化技术,重点关注其在PyTorch框架中的具体实现。关于这些技术的理论基础,建议读者参考相关理论文献以获得更深入的理解。
176 4
提升模型泛化能力:PyTorch的L1、L2、ElasticNet正则化技术深度解析与代码实现
|
7月前
|
机器学习/深度学习 数据可视化 机器人
比扩散策略更高效的生成模型:流匹配的理论基础与Pytorch代码实现
扩散模型和流匹配是生成高分辨率数据(如图像和机器人轨迹)的先进技术。扩散模型通过逐步去噪生成数据,其代表应用Stable Diffusion已扩展至机器人学领域形成“扩散策略”。流匹配作为更通用的方法,通过学习时间依赖的速度场将噪声转化为目标分布,适用于图像生成和机器人轨迹生成,且通常以较少资源实现更快生成。 本文深入解析流匹配在图像生成中的应用,核心思想是将图像视为随机变量的实现,并通过速度场将源分布转换为目标分布。文中提供了一维模型训练实例,展示了如何用神经网络学习速度场,以及使用最大均值差异(MMD)改进训练效果。与扩散模型相比,流匹配结构简单,资源需求低,适合多模态分布生成。
559 13
比扩散策略更高效的生成模型:流匹配的理论基础与Pytorch代码实现
|
7月前
|
机器学习/深度学习 编解码 PyTorch
从零实现基于扩散模型的文本到视频生成系统:技术详解与Pytorch代码实现
本文介绍了一种基于扩散模型的文本到视频生成系统,详细展示了模型架构、训练流程及生成效果。通过3D U-Net结构和多头注意力机制,模型能够根据文本提示生成高质量视频。
313 1
从零实现基于扩散模型的文本到视频生成系统:技术详解与Pytorch代码实现
|
机器学习/深度学习 PyTorch 算法框架/工具
CNN中的注意力机制综合指南:从理论到Pytorch代码实现
注意力机制已成为深度学习模型的关键组件,尤其在卷积神经网络(CNN)中发挥了重要作用。通过使模型关注输入数据中最相关的部分,注意力机制显著提升了CNN在图像分类、目标检测和语义分割等任务中的表现。本文将详细介绍CNN中的注意力机制,包括其基本概念、不同类型(如通道注意力、空间注意力和混合注意力)以及实际实现方法。此外,还将探讨注意力机制在多个计算机视觉任务中的应用效果及其面临的挑战。无论是图像分类还是医学图像分析,注意力机制都能显著提升模型性能,并在不断发展的深度学习领域中扮演重要角色。
669 10
|
9月前
|
机器学习/深度学习 存储 算法
近端策略优化(PPO)算法的理论基础与PyTorch代码详解
近端策略优化(PPO)是深度强化学习中高效的策略优化方法,广泛应用于大语言模型的RLHF训练。PPO通过引入策略更新约束机制,平衡了更新幅度,提升了训练稳定性。其核心思想是在优势演员-评论家方法的基础上,采用裁剪和非裁剪项组成的替代目标函数,限制策略比率在[1-ϵ, 1+ϵ]区间内,防止过大的策略更新。本文详细探讨了PPO的基本原理、损失函数设计及PyTorch实现流程,提供了完整的代码示例。
4110 10
近端策略优化(PPO)算法的理论基础与PyTorch代码详解
|
9月前
|
存储 机器学习/深度学习 PyTorch
PyTorch Profiler 性能优化示例:定位 TorchMetrics 收集瓶颈,提高 GPU 利用率
本文探讨了机器学习项目中指标收集对训练性能的影响,特别是如何通过简单实现引入不必要的CPU-GPU同步事件,导致训练时间增加约10%。使用TorchMetrics库和PyTorch Profiler工具,文章详细分析了性能瓶颈的根源,并提出了多项优化措施
438 1
PyTorch Profiler 性能优化示例:定位 TorchMetrics 收集瓶颈,提高 GPU 利用率

热门文章

最新文章

推荐镜像

更多