下面以 PyG 官方的数据集（OGB_MAG）和示例代码来复现这个问题：
# Minimal reproduction of the `to_hetero` failure on OGB_MAG.
#
# Observed outcome when run:
#   1. `print(data)` shows the HeteroData structure.
#   2. `to_hetero` emits a UserWarning that node type 'author' never occurs as
#      a destination type in any edge type.
#   3. `to_hetero` then raises NotImplementedError.
import torch
import torch_geometric.transforms as T  # unused in this reproduction script
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero

dataset = OGB_MAG(root='files/pyg_data', preprocess='metapath2vec')
data = dataset[0]
print(data)


class GNN(torch.nn.Module):
    """Two-layer GraphSAGE model written for a homogeneous graph.

    `to_hetero` below rewrites it into a heterogeneous model driven by
    `data.metadata()` (node types and edge types).

    Args:
        hidden_channels: Output size of the first SAGEConv layer.
        out_channels: Output size of the second SAGEConv layer.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1): input sizes are left unspecified and inferred on the
        # first forward pass (see the "Initialize lazy modules" call below).
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x


model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
# Raises NotImplementedError here.
# 'author' is never a destination node type in this graph (see the
# UserWarning), so the first layer produces no updated 'author' representation.
# The second layer's author-sourced edge types, e.g. (author, writes, paper),
# then have no input to consume.
model = to_hetero(model, data.metadata(), aggr='sum')

with torch.no_grad():  # Initialize lazy modules.
    out = model(data.x_dict, data.edge_index_dict)
print(out)
运行后的输出信息如下（先打印数据集结构，随后在调用 `to_hetero` 时给出警告并抛出异常）：
HeteroData(
  paper={
    x=[736389, 128],
    year=[736389],
    y=[736389],
    train_mask=[736389],
    val_mask=[736389],
    test_mask=[736389]
  },
  author={ x=[1134649, 128] },
  institution={ x=[8740, 128] },
  field_of_study={ x=[59965, 128] },
  (author, affiliated_with, institution)={ edge_index=[2, 1043998] },
  (author, writes, paper)={ edge_index=[2, 7145660] },
  (paper, cites, paper)={ edge_index=[2, 5416271] },
  (paper, has_topic, field_of_study)={ edge_index=[2, 7505078] }
)
my_env/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py:145: UserWarning: There exist node types ({'author'}) whose representations do not get updated during message passing as they do not occur as destination type in any edge type. This may lead to unexpected behaviour.
  warnings.warn(
Traceback (most recent call last):
  File "try2.py", line 25, in <module>
    model = to_hetero(model, data.metadata(), aggr='sum')
  File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 118, in to_hetero
    return transformer.transform()
  File "env_path/lib/python3.8/site-packages/torch_geometric/nn/fx.py", line 157, in transform
    getattr(self, op)(node, node.target, node.name)
  File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 294, in call_method
    args, kwargs = self.map_args_kwargs(node, key)
  File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 397, in map_args_kwargs
    args = tuple(_recurse(v) for v in node.args)
  File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 397, in <genexpr>
    args = tuple(_recurse(v) for v in node.args)
  File "env_path/lib/python3.8/site-packages/torch_geometric/nn/to_hetero_transformer.py", line 387, in _recurse
    raise NotImplementedError
NotImplementedError
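从警告信息可以很容易地看出，问题在于 `author` 这种节点类型没有入边，即它不是任何边类型的目标节点类型。因此，第一层卷积之后 `author` 没有得到新的表示。第二层卷积中以 `author` 为源节点的边类型（如 `(author, writes, paper)`）就拿不到输入，于是 `to_hetero` 抛出了 `NotImplementedError`。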
解决方案就是让每一种节点类型都至少作为某个边类型的目标节点（即都有入边）。例如，在加载数据集时加上 `T.ToUndirected()` 变换：
# Fixed version: identical to the reproduction script except that the dataset
# is loaded with `transform=T.ToUndirected()`.
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.nn import SAGEConv, to_hetero

# ToUndirected() adds a reverse edge type for each relation between different
# node types, e.g. (paper, rev_writes, author). It also adds reverse edges
# inside same-type relations such as (paper, cites, paper), whose edge count
# grows in the printed output. Afterwards every node type, including 'author',
# is the destination of at least one edge type, so `to_hetero` can update all
# of them.
dataset = OGB_MAG(root='files/pyg_data', preprocess='metapath2vec',
                  transform=T.ToUndirected())
data = dataset[0]
print(data)


class GNN(torch.nn.Module):
    """Two-layer GraphSAGE model written for a homogeneous graph.

    `to_hetero` below rewrites it into a heterogeneous model driven by
    `data.metadata()` (node types and edge types).

    Args:
        hidden_channels: Output size of the first SAGEConv layer.
        out_channels: Output size of the second SAGEConv layer.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1): input sizes are left unspecified and inferred on the
        # first forward pass (see the "Initialize lazy modules" call below).
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x


model = GNN(hidden_channels=64, out_channels=dataset.num_classes)
model = to_hetero(model, data.metadata(), aggr='sum')

with torch.no_grad():  # Initialize lazy modules.
    # Returns a dict mapping each node type to its output tensor.
    out = model(data.x_dict, data.edge_index_dict)
print(out)
`T.ToUndirected()` 将异质图转换为无向图。它为每种异构关系添加了反向边类型（如 `(paper, rev_writes, author)`），这样所有节点类型都有了入边，就能得到正常的输出结果：
HeteroData(
  paper={
    x=[736389, 128],
    year=[736389],
    y=[736389],
    train_mask=[736389],
    val_mask=[736389],
    test_mask=[736389]
  },
  author={ x=[1134649, 128] },
  institution={ x=[8740, 128] },
  field_of_study={ x=[59965, 128] },
  (author, affiliated_with, institution)={ edge_index=[2, 1043998] },
  (author, writes, paper)={ edge_index=[2, 7145660] },
  (paper, cites, paper)={ edge_index=[2, 10792672] },
  (paper, has_topic, field_of_study)={ edge_index=[2, 7505078] },
  (institution, rev_affiliated_with, author)={ edge_index=[2, 1043998] },
  (paper, rev_writes, author)={ edge_index=[2, 7145660] },
  (field_of_study, rev_has_topic, paper)={ edge_index=[2, 7505078] }
)
{'paper': tensor([[-0.8212, -0.2630, -0.7286, ..., 1.1904, 0.1617, -0.5388],
        [-1.2484, -0.3707, -1.0336, ..., 0.9618, -0.0373, -0.1125],
        [-0.5375, 0.0357, -0.6772, ..., 1.2185, 0.2292, -0.2130],
        ...,
        [-0.9934, -0.2688, -0.9547, ..., 1.3144, 0.1519, -0.2015],
        [-1.4711, -0.6607, -0.7509, ..., 2.3383, 0.6815, -1.0679],
        [-0.4352, -0.4255, -0.6907, ..., 1.1532, 0.1152, -0.9703]]), 'author': tensor([[-0.2782, 0.1771, 0.4187, ..., -0.5233, -0.2969, 0.2438],
        [-0.4543, 0.1019, 0.1637, ..., -0.7748, -0.2809, 0.2598],
        [-0.1613, -0.0481, -0.2491, ..., -0.6227, -0.4217, 0.1335],
        ...,
        [-0.4908, 0.2382, 0.2973, ..., -0.7266, -0.2486, 0.6449],
        [-0.2819, 0.0125, 0.9843, ..., -1.9652, -0.4280, -0.4842],
        [-0.4236, -0.1222, 1.0246, ..., -2.0615, -0.3246, -0.1771]]), 'institution': tensor([[ 0.3911, -1.3527, -0.6624, ..., 0.2732, 0.5270, 0.5756],
        [ 0.1512, -0.6687, -0.6516, ..., 0.1482, 0.2535, 0.1935],
        [ 0.1933, -1.1643, -0.4936, ..., 0.5382, 0.3407, 0.2199],
        ...,
        [ 0.1489, -0.3021, -0.3390, ..., 0.2690, 0.1571, -0.0781],
        [ 0.1855, -0.4848, -0.3205, ..., 0.4728, 0.0659, 0.1500],
        [ 0.1724, -0.0682, -0.0894, ..., 0.1189, 0.1230, -0.2249]]), 'field_of_study': tensor([[ 0.1929, -0.5402, -0.5714, ..., -0.4296, 0.4376, -0.0660],
        [-0.2281, 0.0773, -0.0486, ..., -0.0544, -0.2894, 0.2706],
        [-0.2798, -0.1967, -0.3376, ..., -0.3098, -0.1610, 0.1120],
        ...,
        [ 0.0775, -0.5927, -0.6084, ..., -0.3190, 0.2483, -0.1418],
        [ 0.0286, -0.7393, -0.6629, ..., -0.4745, 0.8461, -0.1554],
        [-0.0804, -0.5598, -0.8517, ..., -0.2317, 0.3234, -0.0520]])}