RGCN模型成功运行案例

简介: # 创建模型in_feats = 3hid_feats = 4out_feats = 2rel_num = 4model = RGCN(in_feats, hid_feats, out_feats, rel_num)# 随机生成特征features = torch.randn((10, 3))# 计算输出output = model(g, features, rel_type)print(output)
import torch
from torch import nn
import dgl
from dgl.nn.pytorch import RelGraphConv
# 定义图结构
edges = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
edges_src = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
edges_dst = torch.tensor([1, 2, 3, 0, 4, 5, 6, 7, 8, 9])
rel_type = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3])
graph = (edges_src, edges_dst)
# 将元组图结构转换为DGLGraph对象
g = dgl.graph(graph)
# 定义模型
class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_num):
        super(RGCN, self).__init__()
        self.conv1 = RelGraphConv(in_feats, hid_feats, rel_num)
        self.conv2 = RelGraphConv(hid_feats, out_feats, rel_num)
    def forward(self, g, feats, rel_type):
        h = self.conv1(g, feats, rel_type)
        h = torch.relu(h)
        h = self.conv2(g, h, rel_type)
        return h
# 创建模型
in_feats = 3
hid_feats = 4
out_feats = 2
rel_num = 4
model = RGCN(in_feats, hid_feats, out_feats, rel_num)
# 随机生成特征
features = torch.randn((10, 3))
# 计算输出
output = model(g, features, rel_type)
print(output)

上面的代码主要是实现了一个基于关系图卷积网络(RGCN)的模型。

代码实现的详细解释如下:

  1. 首先,使用 PyTorch 定义了一个 RGCN 模型,并通过构造函数中的 in_feats, hid_feats, out_featsrel_num 来指定输入特征维度,隐藏层维度,输出特征维度和关系类型数量。


class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_num):
        super(RGCN, self).__init__()
        self.conv1 = RelGraphConv(in_feats, hid_feats, rel_num)
        self.conv2 = RelGraphConv(hid_feats, out_feats, rel_num)
  1. 在模型的前向传递函数中,将输入特征 feats 和关系类型 rel_type 通过 RelGraphConv 进行关系图卷积,同时使用 ReLU 激活函数对隐藏层的特征进行非线性变换。


def forward(self, g, feats, rel_type):
    h = self.conv1(g, feats, rel_type)
    h = torch.relu(h)
    h = self.conv2(g, h, rel_type)
    return h
  1. 接着,使用 PyTorch 创建了一个大小为 10x3 的张量 features 作为输入特征。
features = torch.randn((10, 3))
  1. 将定义好的元组图结构 graph 转换为 DGLGraph 对象 g


g = dgl.graph(graph)
  1. 最后,将 g, featuresrel_type 作为输入,通过调用模型的前向传递函数 forward 计算输出特征 output,并将其打印出来。
output = model(g, features, rel_type)
print(output)

在计算输出时,模型会将输入的元组图结构 graph 转换为 DGLGraph 对象 gg 的结构与 graph 是相同的,但是由于它是 DGLGraph 对象,因此可以使用 DGL 提供的一系列 API 来进行图的操作和计算,例如采样、聚合和传播等。在 RGCN 中,DGL 提供的关系图卷积层(RelGraphConv)用于实现图卷积计算。


----------------介绍整个代码实现了什么功能?


这段代码实现了一个基于关系图卷积神经网络(R-GCN)的节点分类任务,具体实现过程如下:

  1. 定义图结构:使用 PyTorch 张量定义了一张包含 10 条边的有向图,并指定每条边的起点、终点和关系类型。
  2. 将元组图结构转换为 DGLGraph 对象:使用 DGL 库提供的函数将元组图结构转换为 DGLGraph 对象,以便后续在其上运行图卷积神经网络。
  3. 定义 R-GCN 模型:使用 PyTorch 搭建了一个包含两个 R-GCN 层的模型,每个层都使用了相同的关系类型数,并指定了输入特征维度、隐藏层维度和输出特征维度。
  4. 随机生成特征:使用 PyTorch 生成了一个形状为 (10, 3) 的张量作为节点特征。
  5. 计算输出:将定义好的模型、DGLGraph 对象和节点特征输入模型,计算输出。输出的形状为 (10, 2),即 10 个节点对应的分类概率。最后将输出打印出来。


解释下面的含义:


edges = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
edges_src = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
edges_dst = torch.tensor([1, 2, 3, 0, 4, 5, 6, 7, 8, 9])
rel_type = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3])

这里定义了一个有向图,其中图中包含10个节点,通过分别指定edges_src和edges_dst中的源节点和目标节点来定义每条边。rel_type是用来指定每个边的关系类型,共有4种类型,分别用0、1、2、3表示。具体来说,这里定义了以下10条有向边:


  • 0 -> 1 (关系类型为0)
  • 1 -> 2 (关系类型为0)
  • 2 -> 3 (关系类型为0)
  • 3 -> 0 (关系类型为1)
  • 4 -> 1 (关系类型为1)
  • 5 -> 2 (关系类型为1)
  • 6 -> 3 (关系类型为2)
  • 7 -> 4 (关系类型为2)
  • 8 -> 5 (关系类型为2)
  • 9 -> 6 (关系类型为3)


上述代码定义了一个图结构,其中包含10个节点和10条边。变量含义如下:

  • edges: 表示图中10条边的编号,取值为 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
  • edges_src: 表示每条边的源节点编号,取值为 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],与edges一一对应。
  • edges_dst: 表示每条边的目标节点编号,取值为 [1, 2, 3, 0, 4, 5, 6, 7, 8, 9],与edges一一对应。
  • rel_type: 表示每条边的关系类型,取值为 [0, 0, 0, 1, 1, 1, 2, 2, 2, 3],与edges一一对应。
  • graph: 表示由节点和边构成的图结构,它是一个元组,包含了两个张量 edges_srcedges_dst


最后这段代码实现了以下功能:

# 创建模型
in_feats = 3
hid_feats = 4
out_feats = 2
rel_num = 4
model = RGCN(in_feats, hid_feats, out_feats, rel_num)
# 随机生成特征
features = torch.randn((10, 3))
# 计算输出
output = model(g, features, rel_type)
print(output)
  1. 创建一个RGCN模型对象 model,该模型具有 in_feats 个输入特征、hid_feats 个隐藏特征、out_feats 个输出特征和 rel_num 种不同的关系类型。
  2. 使用 torch.randn 随机生成一个 $10 \times 3$ 大小的特征矩阵 features
  3. 将特征矩阵 features、关系类型张量 rel_type 和转换后的DGL图 g 作为输入,通过 model 模型计算输出特征矩阵 output
  4. 输出 output


输出

tensor([[ 0.5757, -0.0934],
        [ 0.9615,  1.4563],
        [-3.5925, -0.7869],
        [ 1.2882, -0.3457],
        [ 2.5402, -0.2980],
        [ 0.1554,  0.6599],
        [ 3.8173,  2.2265],
        [ 0.8300,  1.1929],
        [ 2.6410,  3.7959],
        [-1.5862, -0.5873]], grad_fn=<AddBackward0>)


























目录
相关文章
|
机器学习/深度学习 PyTorch 算法框架/工具
RGCN的torch简单案例
RGCN 是指 Relational Graph Convolutional Network,是一种基于图卷积神经网络(GCN)的模型。与传统的 GCN 不同的是,RGCN 可以处理具有多种关系(边)类型的图数据,从而更好地模拟现实世界中的实体和它们之间的复杂关系。 RGCN 可以用于多种任务,例如知识图谱推理、社交网络分析、药物发现等。以下是一个以知识图谱推理为例的应用场景: 假设我们有一个知识图谱,其中包含一些实体(如人、物、地点)以及它们之间的关系(如出生于、居住在、工作于)。图谱可以表示为一个二元组 (E, R),其中 E 表示实体的集合,R 表示关系的集合,每个关系 r ∈ R
2682 0
|
缓存 PyTorch 数据处理
基于Pytorch的PyTorch Geometric(PYG)库构造个人数据集
基于Pytorch的PyTorch Geometric(PYG)库构造个人数据集
1570 0
基于Pytorch的PyTorch Geometric(PYG)库构造个人数据集
|
3月前
|
人工智能 弹性计算 安全
OpenClaw(龙虾)一键秒级部署指南,两步解锁专属AI助理!
本文将为大家分享OpenClaw(龙虾)一键秒级部署指南,只需两步,即可轻松拥有专属AI助理!
772 5
|
5月前
|
机器学习/深度学习 监控 数据可视化
基于YOLOv8的停车场空车位目标检测项目|完整源码数据集+PyQt5界面+完整训练流程+开箱即用!
本项目基于YOLOv8实现停车场空车位智能检测,支持Occupied/Vacant双类别识别,集成PyQt5图形界面,兼容图片、视频、摄像头等多源输入。提供完整源码、标注数据集、预训练权重及详细教程,开箱即用,适用于毕设、科研与智慧停车原型开发。
基于YOLOv8的停车场空车位目标检测项目|完整源码数据集+PyQt5界面+完整训练流程+开箱即用!
|
14天前
|
人工智能 芯片 开发者
倒计时2天!直击2026阿里云峰会:从底层芯片到Agentic OS,揭秘AI全栈新范式
2026阿里云峰会将于5月20–21日在杭州西子宾馆举办,聚焦“Agentic Cloud”,全景展示芯片、大模型到推理平台的全栈升级,深度探讨智能体(Agent)时代新范式,为开发者提供前瞻技术洞察与实践机遇。
|
编译器
overleaf 参考文献引用,创建引用目录.bib文件,在文档中引用参考文献,生成参考文献列表
overleaf 参考文献引用,创建引用目录.bib文件,在文档中引用参考文献,生成参考文献列表
12053 0
|
12月前
|
机器学习/深度学习 自然语言处理 安全
ACL 2025 | GALLa:用图结构增强代码大模型,让代码理解更精准!
通过级联多模态架构将代码结构图对齐到大模型表征中
848 69
|
10月前
|
前端开发 Java API
2025 年 Java 全栈从环境搭建到项目上线实操全流程指南:Java 全栈最新实操指南(2025 版)
本指南涵盖2025年Java全栈开发核心技术,从JDK 21环境搭建、Spring Boot 3.3实战、React前端集成到Docker容器化部署,结合最新特性与实操流程,助力构建高效企业级应用。
2959 1
|
存储 JavaScript 前端开发
使用Vue.js构建交互式前端界面的技术探索
【5月更文挑战第20天】Vue.js是一款渐进式JavaScript框架,擅长构建交互式前端界面。其核心特性包括响应式数据绑定、组件化开发、指令系统和虚拟DOM,简化开发并提升性能。通过Vue CLI创建项目,拆分组件,结合数据绑定和事件处理实现交互,使用Vue Router管理路由,Vuex进行状态管理,能高效构建现代Web应用。
|
弹性计算 负载均衡 网络协议
slb健康检查
【9月更文挑战第2天】
662 10

热门文章

最新文章