1. 数据获取
2. 数据预处理
- 对节点特征进行行归一化(T.NormalizeFeatures(),文档https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html#torch_geometric.transforms.NormalizeFeatures,源码见 torch_geometric.transforms.normalize_features 模块):使每一行总和为1、且更稀疏,具体做法是:元素先减去最小值,然后除以其所在行的元素总和(行总和的下限截断为1,即 clamp(min=1),以避免除以0)
- 将DataSet对象放到GPU上(T.ToDevice(device),文档https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html#torch_geometric.transforms.ToDevice)
- 对DataSet对象用链路预测的方法进行数据集划分:
T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, add_negative_train_samples=False)
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid

# Prefer the first CUDA device when one is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Transforms are applied in order every time a graph is read from the dataset:
# feature normalisation, device transfer, then the link-level split.
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(
        num_val=0.05,
        num_test=0.1,
        is_undirected=True,
        # Training negatives are not produced by the split itself.
        add_negative_train_samples=False,
    ),
])

dataset = Planetoid('pyg_data/Planetoid', name='Cora', transform=transform)
print(type(dataset))

# Because of RandomLinkSplit, indexing the dataset yields three objects
# (train / validation / test) instead of a single graph.
train_data, val_data, test_data = dataset[0]
print(type(train_data))
<class 'torch_geometric.datasets.planetoid.Planetoid'> <class 'torch_geometric.data.data.Data'>
3. 建立链路预测模型
- encode()函数:GNN节点表征,使用2层GCN,其中用了ReLU激活函数。没有其他trick。
- decode()函数在训练和评估(计算AUC)时使用,仅计算edge_label_index指定的节点对的得分,在代码上用两端节点表征逐元素相乘后再求和来表示点积。
- decode_all()函数在训练结束后的最终推断时使用,计算整张图所有节点对之间存在边的得分(logit,未经过Sigmoid),也是用点积,只是用矩阵乘法一次性算出所有节点对的结果;得分大于0(等价于经过Sigmoid后的概率大于0.5)就直接认为节点对之间存在边,返回的是这个被认为存在边的edge list。
import torch
from torch_geometric.nn import GCNConv


class Net(torch.nn.Module):
    """Two-layer GCN encoder with dot-product decoders for link prediction."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        """Compute node embeddings: conv1 -> ReLU -> conv2."""
        hidden = self.conv1(x, edge_index)
        hidden = hidden.relu()
        return self.conv2(hidden, edge_index)

    def decode(self, z, edge_label_index):
        """Score only the node pairs listed in ``edge_label_index``.

        Row 0 and row 1 of ``edge_label_index`` select the two endpoints of
        each pair; the score is the dot product of their embeddings, written
        as an element-wise product followed by a sum over the last dimension.
        """
        endpoints_a = z[edge_label_index[0]]
        endpoints_b = z[edge_label_index[1]]
        return (endpoints_a * endpoints_b).sum(dim=-1)

    def decode_all(self, z):
        """Return the index pairs of all node pairs with a positive score.

        ``z @ z.t()`` holds the dot product of every pair of embeddings.
        The result is transposed so that it has two rows (one per endpoint).
        """
        pair_scores = z @ z.t()
        positive_mask = pair_scores > 0
        return positive_mask.nonzero(as_tuple=False).t()
4. 实例化模型,设置优化器、损失函数
# Loss function: torch.nn.BCEWithLogitsLoss(). It is applied to the raw
# (un-sigmoided) scores returned by the decoder.
# Net(in_channels, hidden_channels, out_channels): 128 hidden, 64 output dims.
model = Net(dataset.num_features, 128, 64).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()
5. 构建训练函数
from torch_geometric.utils import negative_sampling


def train():
    """Run one optimisation step on ``train_data`` and return the loss tensor."""
    model.train()
    optimizer.zero_grad()

    embeddings = model.encode(train_data.x, train_data.edge_index)

    # Fresh negatives are drawn on every call (i.e. every epoch), one per
    # entry already present in train_data.edge_label_index.
    num_negatives = train_data.edge_label_index.size(1)
    sampled_negatives = negative_sampling(
        edge_index=train_data.edge_index,
        num_nodes=train_data.num_nodes,
        num_neg_samples=num_negatives,
        method='sparse',
    )

    # Candidate pairs: the split's labelled edges followed by the negatives.
    candidate_pairs = torch.cat(
        [train_data.edge_label_index, sampled_negatives],
        dim=-1,
    )
    # Targets: the split's labels followed by zeros for the sampled negatives.
    negative_targets = train_data.edge_label.new_zeros(
        sampled_negatives.size(1))
    targets = torch.cat([train_data.edge_label, negative_targets], dim=0)

    logits = model.decode(embeddings, candidate_pairs).view(-1)
    loss = criterion(logits, targets)
    loss.backward()
    optimizer.step()
    return loss
6. 构建每个epoch运行时的测试函数
我个人比较喜欢用with torch.no_grad()
计算data.edge_label_index中各候选边(验证集/测试集中既包含正样本边,也包含数据集划分时采样得到的负样本边)的得分,直接用其通过Sigmoid激活函数后的结果作为边存在的概率,与edge_label一起用以计算ROC AUC值。
from sklearn.metrics import roc_auc_score


@torch.no_grad()
def test(data):
    """Return the ROC AUC of the model's link scores on ``data``.

    The model is switched to eval mode, node embeddings are computed from
    ``data.x`` / ``data.edge_index``, and the pairs in
    ``data.edge_label_index`` are scored. The sigmoid of each score is used
    as the predicted probability and compared against ``data.edge_label``.
    """
    model.eval()
    z = model.encode(data.x, data.edge_index)
    out = model.decode(z, data.edge_label_index).view(-1).sigmoid()
    # roc_auc_score works on NumPy arrays, so move both tensors to the CPU.
    return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())
7. 训练和测试
best_val_auc = final_test_auc = 0
for epoch in range(1, 101):
    loss = train()
    val_auc = test(val_data)
    test_auc = test(test_data)
    # Model selection: remember the test AUC from the epoch with the best
    # validation AUC so far.
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
          f'Test: {test_auc:.4f}')

print(f'Final Test: {final_test_auc:.4f}')

# Final inference over all node pairs. No gradients are needed here, so
# autograd tracking is disabled.
with torch.no_grad():
    z = model.encode(test_data.x, test_data.edge_index)
    final_edge_index = model.decode_all(z)
# Shape is [2, number_of_predicted_edges].
print(final_edge_index.size())
Epoch: 001, Loss: 0.6930, Val: 0.6729, Test: 0.7026 Epoch: 002, Loss: 0.6820, Val: 0.6589, Test: 0.6913 Epoch: 003, Loss: 0.7065, Val: 0.6619, Test: 0.6967 Epoch: 004, Loss: 0.6766, Val: 0.6686, Test: 0.7069 Epoch: 005, Loss: 0.6842, Val: 0.6716, Test: 0.7128 Epoch: 006, Loss: 0.6876, Val: 0.6637, Test: 0.7132 Epoch: 007, Loss: 0.6881, Val: 0.6471, Test: 0.7009 Epoch: 008, Loss: 0.6867, Val: 0.6317, Test: 0.6859 Epoch: 009, Loss: 0.6829, Val: 0.6240, Test: 0.6767 Epoch: 010, Loss: 0.6765, Val: 0.6223, Test: 0.6720 Epoch: 011, Loss: 0.6715, Val: 0.6208, Test: 0.6684 Epoch: 012, Loss: 0.6759, Val: 0.6204, Test: 0.6640 Epoch: 013, Loss: 0.6687, Val: 0.6272, Test: 0.6656 Epoch: 014, Loss: 0.6621, Val: 0.6488, Test: 0.6778 Epoch: 015, Loss: 0.6593, Val: 0.6748, Test: 0.6907 Epoch: 016, Loss: 0.6534, Val: 0.6824, Test: 0.6923 Epoch: 017, Loss: 0.6477, Val: 0.6796, Test: 0.6867 Epoch: 018, Loss: 0.6389, Val: 0.6847, Test: 0.6888 Epoch: 019, Loss: 0.6332, Val: 0.7155, Test: 0.7115 Epoch: 020, Loss: 0.6217, Val: 0.7487, Test: 0.7430 Epoch: 021, Loss: 0.6060, Val: 0.7645, Test: 0.7582 Epoch: 022, Loss: 0.5993, Val: 0.7650, Test: 0.7574 Epoch: 023, Loss: 0.5837, Val: 0.7632, Test: 0.7550 Epoch: 024, Loss: 0.5719, Val: 0.7612, Test: 0.7530 Epoch: 025, Loss: 0.5654, Val: 0.7565, Test: 0.7518 Epoch: 026, Loss: 0.5697, Val: 0.7574, Test: 0.7534 Epoch: 027, Loss: 0.5676, Val: 0.7610, Test: 0.7576 Epoch: 028, Loss: 0.5551, Val: 0.7629, Test: 0.7634 Epoch: 029, Loss: 0.5446, Val: 0.7682, Test: 0.7723 Epoch: 030, Loss: 0.5422, Val: 0.7774, Test: 0.7848 Epoch: 031, Loss: 0.5259, Val: 0.7896, Test: 0.7988 Epoch: 032, Loss: 0.5277, Val: 0.8005, Test: 0.8127 Epoch: 033, Loss: 0.5218, Val: 0.8135, Test: 0.8245 Epoch: 034, Loss: 0.5156, Val: 0.8234, Test: 0.8342 Epoch: 035, Loss: 0.5057, Val: 0.8285, Test: 0.8414 Epoch: 036, Loss: 0.4981, Val: 0.8314, Test: 0.8462 Epoch: 037, Loss: 0.4984, Val: 0.8302, Test: 0.8459 Epoch: 038, Loss: 0.4960, Val: 0.8332, Test: 0.8489 Epoch: 039, Loss: 
0.4873, Val: 0.8381, Test: 0.8555 Epoch: 040, Loss: 0.4883, Val: 0.8418, Test: 0.8609 Epoch: 041, Loss: 0.4993, Val: 0.8427, Test: 0.8615 Epoch: 042, Loss: 0.4852, Val: 0.8452, Test: 0.8616 Epoch: 043, Loss: 0.4718, Val: 0.8474, Test: 0.8640 Epoch: 044, Loss: 0.4768, Val: 0.8492, Test: 0.8679 Epoch: 045, Loss: 0.4708, Val: 0.8472, Test: 0.8688 Epoch: 046, Loss: 0.4726, Val: 0.8457, Test: 0.8680 Epoch: 047, Loss: 0.4729, Val: 0.8500, Test: 0.8698 Epoch: 048, Loss: 0.4726, Val: 0.8517, Test: 0.8705 Epoch: 049, Loss: 0.4730, Val: 0.8527, Test: 0.8722 Epoch: 050, Loss: 0.4715, Val: 0.8521, Test: 0.8734 Epoch: 051, Loss: 0.4667, Val: 0.8547, Test: 0.8756 Epoch: 052, Loss: 0.4609, Val: 0.8577, Test: 0.8784 Epoch: 053, Loss: 0.4632, Val: 0.8607, Test: 0.8829 Epoch: 054, Loss: 0.4612, Val: 0.8626, Test: 0.8862 Epoch: 055, Loss: 0.4591, Val: 0.8646, Test: 0.8878 Epoch: 056, Loss: 0.4568, Val: 0.8644, Test: 0.8874 Epoch: 057, Loss: 0.4569, Val: 0.8656, Test: 0.8874 Epoch: 058, Loss: 0.4568, Val: 0.8688, Test: 0.8897 Epoch: 059, Loss: 0.4516, Val: 0.8721, Test: 0.8929 Epoch: 060, Loss: 0.4567, Val: 0.8729, Test: 0.8942 Epoch: 061, Loss: 0.4625, Val: 0.8742, Test: 0.8938 Epoch: 062, Loss: 0.4547, Val: 0.8729, Test: 0.8919 Epoch: 063, Loss: 0.4479, Val: 0.8723, Test: 0.8927 Epoch: 064, Loss: 0.4517, Val: 0.8728, Test: 0.8962 Epoch: 065, Loss: 0.4517, Val: 0.8719, Test: 0.8972 Epoch: 066, Loss: 0.4538, Val: 0.8726, Test: 0.8962 Epoch: 067, Loss: 0.4532, Val: 0.8718, Test: 0.8944 Epoch: 068, Loss: 0.4540, Val: 0.8725, Test: 0.8937 Epoch: 069, Loss: 0.4542, Val: 0.8734, Test: 0.8953 Epoch: 070, Loss: 0.4487, Val: 0.8726, Test: 0.8967 Epoch: 071, Loss: 0.4497, Val: 0.8727, Test: 0.8973 Epoch: 072, Loss: 0.4539, Val: 0.8694, Test: 0.8949 Epoch: 073, Loss: 0.4478, Val: 0.8703, Test: 0.8937 Epoch: 074, Loss: 0.4449, Val: 0.8737, Test: 0.8945 Epoch: 075, Loss: 0.4486, Val: 0.8770, Test: 0.8968 Epoch: 076, Loss: 0.4491, Val: 0.8724, Test: 0.8970 Epoch: 077, Loss: 0.4431, Val: 0.8678, 
Test: 0.8957 Epoch: 078, Loss: 0.4447, Val: 0.8688, Test: 0.8952 Epoch: 079, Loss: 0.4540, Val: 0.8704, Test: 0.8943 Epoch: 080, Loss: 0.4548, Val: 0.8741, Test: 0.8955 Epoch: 081, Loss: 0.4468, Val: 0.8746, Test: 0.8985 Epoch: 082, Loss: 0.4495, Val: 0.8727, Test: 0.8994 Epoch: 083, Loss: 0.4473, Val: 0.8708, Test: 0.8990 Epoch: 084, Loss: 0.4464, Val: 0.8715, Test: 0.8976 Epoch: 085, Loss: 0.4376, Val: 0.8755, Test: 0.8977 Epoch: 086, Loss: 0.4455, Val: 0.8762, Test: 0.8993 Epoch: 087, Loss: 0.4442, Val: 0.8727, Test: 0.9004 Epoch: 088, Loss: 0.4411, Val: 0.8726, Test: 0.9009 Epoch: 089, Loss: 0.4445, Val: 0.8760, Test: 0.9010 Epoch: 090, Loss: 0.4474, Val: 0.8780, Test: 0.9002 Epoch: 091, Loss: 0.4468, Val: 0.8754, Test: 0.9009 Epoch: 092, Loss: 0.4470, Val: 0.8712, Test: 0.9015 Epoch: 093, Loss: 0.4467, Val: 0.8680, Test: 0.9006 Epoch: 094, Loss: 0.4454, Val: 0.8720, Test: 0.9019 Epoch: 095, Loss: 0.4355, Val: 0.8761, Test: 0.9028 Epoch: 096, Loss: 0.4486, Val: 0.8749, Test: 0.9013 Epoch: 097, Loss: 0.4418, Val: 0.8695, Test: 0.8999 Epoch: 098, Loss: 0.4396, Val: 0.8651, Test: 0.9002 Epoch: 099, Loss: 0.4365, Val: 0.8684, Test: 0.9034 Epoch: 100, Loss: 0.4428, Val: 0.8720, Test: 0.9050 Final Test: 0.9002 torch.Size([2, 3262820])
8. 整体代码
import torch
from sklearn.metrics import roc_auc_score

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
from torch_geometric.utils import negative_sampling

# Prefer the first CUDA device when one is available; otherwise use the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Applied in order each time a graph is read: feature normalisation, device
# transfer, then the link-level train/val/test split.
transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,
                      add_negative_train_samples=False),
])
dataset = Planetoid('/data/pyg_data/Planetoid', name='Cora',
                    transform=transform)
train_data, val_data, test_data = dataset[0]


class Net(torch.nn.Module):
    """Two-layer GCN encoder with dot-product decoders for link prediction."""

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        """Compute node embeddings: conv1 -> ReLU -> conv2."""
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        """Score the node pairs in ``edge_label_index`` by a dot product."""
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        """Return index pairs (two rows) of all pairs with a positive score."""
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()


model = Net(dataset.num_features, 128, 64).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
# Operates on the raw decoder scores (no sigmoid is applied before the loss).
criterion = torch.nn.BCEWithLogitsLoss()


def train():
    """Run one optimisation step on ``train_data`` and return the loss tensor."""
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)

    # We perform a new round of negative sampling for every training epoch:
    neg_edge_index = negative_sampling(
        edge_index=train_data.edge_index,
        num_nodes=train_data.num_nodes,
        num_neg_samples=train_data.edge_label_index.size(1),
        method='sparse')

    # Labelled pairs from the split, followed by the sampled negatives.
    edge_label_index = torch.cat(
        [train_data.edge_label_index, neg_edge_index],
        dim=-1,
    )
    # Matching targets: the split's labels, then zeros for the negatives.
    edge_label = torch.cat([
        train_data.edge_label,
        train_data.edge_label.new_zeros(neg_edge_index.size(1))
    ], dim=0)

    out = model.decode(z, edge_label_index).view(-1)
    loss = criterion(out, edge_label)
    loss.backward()
    optimizer.step()
    return loss


@torch.no_grad()
def test(data):
    """Return the ROC AUC of the model's link scores on ``data``."""
    model.eval()
    z = model.encode(data.x, data.edge_index)
    out = model.decode(z, data.edge_label_index).view(-1).sigmoid()
    # roc_auc_score works on NumPy arrays, so move both tensors to the CPU.
    return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())


best_val_auc = final_test_auc = 0
for epoch in range(1, 101):
    loss = train()
    val_auc = test(val_data)
    test_auc = test(test_data)
    # Model selection: remember the test AUC from the epoch with the best
    # validation AUC so far.
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        final_test_auc = test_auc
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
          f'Test: {test_auc:.4f}')

print(f'Final Test: {final_test_auc:.4f}')

# Final inference over all node pairs. No gradients are needed here, so
# autograd tracking is disabled.
with torch.no_grad():
    z = model.encode(test_data.x, test_data.edge_index)
    final_edge_index = model.decode_all(z)
print(final_edge_index.size())