PyG异质图神经网络NotImplementedError问题

简介: PyG异质图神经网络NotImplementedError问题

以PyG官方的数据集和示例代码来复现一下这个问题:


import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero
dataset = OGB_MAG(root='files/pyg_data', preprocess='metapath2vec')
data = dataset[0]
print(data)
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x
model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')
with torch.no_grad():  # Initialize lazy modules.
    out = model(data.x_dict, data.edge_index_dict)
print(out)


输出信息:


HeteroData(
  paper={
    x=[736389, 128],
    year=[736389],
    y=[736389],
    train_mask=[736389],
    val_mask=[736389],
    test_mask=[736389]
  },
  author={ x=[1134649, 128] },
  institution={ x=[8740, 128] },
  field_of_study={ x=[59965, 128] },
  (author, affiliated_with, institution)={ edge_index=[2, 1043998] },
  (author, writes, paper)={ edge_index=[2, 7145660] },
  (paper, cites, paper)={ edge_index=[2, 5416271] },
  (paper, has_topic, field_of_study)={ edge_index=[2, 7505078] }
)
my_env/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py:145: UserWarning: There exist node types ({'author'}) whose representations do not get updated during message passing as they do not occur as destination type in any edge type. This may lead to unexpected behaviour.
  warnings.warn(
Traceback (most recent call last):
  File "try2.py", line 25, in <module>
    model = to_hetero(model, data.metadata(), aggr='sum')
  File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 118, in to_hetero
    return transformer.transform()
  File "env_path/lib/python3.8/site-packages/torch_geometric/nn/fx.py", line 157, in transform
    getattr(self, op)(node, node.target, node.name)
  File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 294, in call_method
    args, kwargs = self.map_args_kwargs(node, key)
  File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 397, in map_args_kwargs
    args = tuple(_recurse(v) for v in node.args)
  File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 397, in <genexpr>
    args = tuple(_recurse(v) for v in node.args)
  File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 387, in _recurse
    raise NotImplementedError
NotImplementedError


可以很容易地看出来,这是由于有一种节点没有入边产生的问题。

解决方案就是使所有节点都有入边。如:


import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero
dataset = OGB_MAG(root='files/pyg_data', preprocess='metapath2vec',transform=T.ToUndirected())
data = dataset[0]
print(data)
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x
model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')
with torch.no_grad():  # Initialize lazy modules.
    out = model(data.x_dict, data.edge_index_dict)
print(out)


将异质图转换为无向图,这样就能得到正常的输出结果:


HeteroData(
  paper={
    x=[736389, 128],
    year=[736389],
    y=[736389],
    train_mask=[736389],
    val_mask=[736389],
    test_mask=[736389]
  },
  author={ x=[1134649, 128] },
  institution={ x=[8740, 128] },
  field_of_study={ x=[59965, 128] },
  (author, affiliated_with, institution)={ edge_index=[2, 1043998] },
  (author, writes, paper)={ edge_index=[2, 7145660] },
  (paper, cites, paper)={ edge_index=[2, 10792672] },
  (paper, has_topic, field_of_study)={ edge_index=[2, 7505078] },
  (institution, rev_affiliated_with, author)={ edge_index=[2, 1043998] },
  (paper, rev_writes, author)={ edge_index=[2, 7145660] },
  (field_of_study, rev_has_topic, paper)={ edge_index=[2, 7505078] }
)
{'paper': tensor([[-0.8212, -0.2630, -0.7286,  ...,  1.1904,  0.1617, -0.5388],
        [-1.2484, -0.3707, -1.0336,  ...,  0.9618, -0.0373, -0.1125],
        [-0.5375,  0.0357, -0.6772,  ...,  1.2185,  0.2292, -0.2130],
        ...,
        [-0.9934, -0.2688, -0.9547,  ...,  1.3144,  0.1519, -0.2015],
        [-1.4711, -0.6607, -0.7509,  ...,  2.3383,  0.6815, -1.0679],
        [-0.4352, -0.4255, -0.6907,  ...,  1.1532,  0.1152, -0.9703]]), 'author': tensor([[-0.2782,  0.1771,  0.4187,  ..., -0.5233, -0.2969,  0.2438],
        [-0.4543,  0.1019,  0.1637,  ..., -0.7748, -0.2809,  0.2598],
        [-0.1613, -0.0481, -0.2491,  ..., -0.6227, -0.4217,  0.1335],
        ...,
        [-0.4908,  0.2382,  0.2973,  ..., -0.7266, -0.2486,  0.6449],
        [-0.2819,  0.0125,  0.9843,  ..., -1.9652, -0.4280, -0.4842],
        [-0.4236, -0.1222,  1.0246,  ..., -2.0615, -0.3246, -0.1771]]), 'institution': tensor([[ 0.3911, -1.3527, -0.6624,  ...,  0.2732,  0.5270,  0.5756],
        [ 0.1512, -0.6687, -0.6516,  ...,  0.1482,  0.2535,  0.1935],
        [ 0.1933, -1.1643, -0.4936,  ...,  0.5382,  0.3407,  0.2199],
        ...,
        [ 0.1489, -0.3021, -0.3390,  ...,  0.2690,  0.1571, -0.0781],
        [ 0.1855, -0.4848, -0.3205,  ...,  0.4728,  0.0659,  0.1500],
        [ 0.1724, -0.0682, -0.0894,  ...,  0.1189,  0.1230, -0.2249]]), 'field_of_study': tensor([[ 0.1929, -0.5402, -0.5714,  ..., -0.4296,  0.4376, -0.0660],
        [-0.2281,  0.0773, -0.0486,  ..., -0.0544, -0.2894,  0.2706],
        [-0.2798, -0.1967, -0.3376,  ..., -0.3098, -0.1610,  0.1120],
        ...,
        [ 0.0775, -0.5927, -0.6084,  ..., -0.3190,  0.2483, -0.1418],
        [ 0.0286, -0.7393, -0.6629,  ..., -0.4745,  0.8461, -0.1554],
        [-0.0804, -0.5598, -0.8517,  ..., -0.2317,  0.3234, -0.0520]])}


相关文章
|
6月前
|
机器学习/深度学习 人工智能 自然语言处理
ICLR 2024 Spotlight:训练一个图神经网络即可解决图领域所有分类问题!
【2月更文挑战第17天】ICLR 2024 Spotlight:训练一个图神经网络即可解决图领域所有分类问题!
195 2
ICLR 2024 Spotlight:训练一个图神经网络即可解决图领域所有分类问题!
|
5月前
|
机器学习/深度学习 人工智能 自然语言处理
Transformer 能代替图神经网络吗?
Transformer模型的革新性在于其自注意力机制,广泛应用于多种任务,包括非原始设计领域。近期研究专注于Transformer的推理能力,特别是在图神经网络(GNN)上下文中。
96 5
|
4月前
|
机器学习/深度学习 搜索推荐 知识图谱
图神经网络加持,突破传统推荐系统局限!北大港大联合提出SelfGNN:有效降低信息过载与数据噪声影响
【7月更文挑战第22天】北大港大联手打造SelfGNN,一种结合图神经网络与自监督学习的推荐系统,专攻信息过载及数据噪声难题。SelfGNN通过短期图捕获实时用户兴趣,利用自增强学习提升模型鲁棒性,实现多时间尺度动态行为建模,大幅优化推荐准确度与时效性。经四大真实数据集测试,SelfGNN在准确性和抗噪能力上超越现有模型。尽管如此,高计算复杂度及对图构建质量的依赖仍是待克服挑战。[详细论文](https://arxiv.org/abs/2405.20878)。
80 5
|
4月前
|
机器学习/深度学习 PyTorch 算法框架/工具
图神经网络是一类用于处理图结构数据的神经网络。与传统的深度学习模型(如卷积神经网络CNN和循环神经网络RNN)不同,
图神经网络是一类用于处理图结构数据的神经网络。与传统的深度学习模型(如卷积神经网络CNN和循环神经网络RNN)不同,
|
4月前
|
机器学习/深度学习 编解码 数据可视化
图神经网络版本的Kolmogorov Arnold(KAN)代码实现和效果对比
目前我们看到有很多使用KAN替代MLP的实验,但是目前来说对于图神经网络来说还没有类似的实验,今天我们就来使用KAN创建一个图神经网络Graph Kolmogorov Arnold(GKAN),来测试下KAN是否可以在图神经网络方面有所作为。
187 0
|
5月前
|
机器学习/深度学习 数据采集 TensorFlow
使用Python实现深度学习模型:图神经网络(GNN)
使用Python实现深度学习模型:图神经网络(GNN)
264 1
|
6月前
|
机器学习/深度学习 自然语言处理 搜索推荐
【传知代码】图神经网络长对话理解-论文复现
在ACL2023会议上发表的论文《使用带有辅助跨模态交互的关系时态图神经网络进行对话理解》提出了一种新方法,名为correct,用于多模态情感识别。correct框架通过全局和局部上下文信息捕捉对话情感,同时有效处理跨模态交互和时间依赖。模型利用图神经网络结构,通过构建图来表示对话中的交互和时间关系,提高了情感预测的准确性。在IEMOCAP和CMU-MOSEI数据集上的实验结果证明了correct的有效性。源码和更多细节可在文章链接提供的附件中获取。
【传知代码】图神经网络长对话理解-论文复现
|
5月前
|
机器学习/深度学习 搜索推荐 PyTorch
【机器学习】图神经网络:深度解析图神经网络的基本构成和原理以及关键技术
【机器学习】图神经网络:深度解析图神经网络的基本构成和原理以及关键技术
1135 2
|
6月前
|
机器学习/深度学习 JSON PyTorch
图神经网络入门示例:使用PyTorch Geometric 进行节点分类
本文介绍了如何使用PyTorch处理同构图数据进行节点分类。首先,数据集来自Facebook Large Page-Page Network,包含22,470个页面,分为四类,具有不同大小的特征向量。为训练神经网络,需创建PyTorch Data对象,涉及读取CSV和JSON文件,处理不一致的特征向量大小并进行归一化。接着,加载边数据以构建图。通过`Data`对象创建同构图,之后数据被分为70%训练集和30%测试集。训练了两种模型:MLP和GCN。GCN在测试集上实现了80%的准确率,优于MLP的46%,展示了利用图信息的优势。
88 1
|
6月前
|
机器学习/深度学习 数据挖掘 算法框架/工具
想要了解图或图神经网络?没有比看论文更好的方式,面试阿里国际站运营一般会问什么
想要了解图或图神经网络?没有比看论文更好的方式,面试阿里国际站运营一般会问什么