千字好文,基于未采样GraphSage算子和DGL实现的图上 Edge 回归

简介: 千字好文,基于未采样GraphSage算子和DGL实现的图上 Edge 回归

千字好文,基于未采样GraphSage算子和DGL实现的图上 Edge 回归

历史图算法相关文章:

(1)一文揭开图机器学习的面纱,你确定不来看看吗


(2) graphSage还是HAN ?吐血力作综述Graph Embeding 经典好文


(3) 看这里,使用docker部署图深度学习框架GraphLearn使用说明


(4) GraphSage与DGL实现同构图 Link 预测,通俗易懂好文强推


(5) 基于GCN和DGL实现的图上 node 分类, 值得一看!!!


书接上文,在前面的几篇文章中,我们对在 图上跑机器学习/深度学习模型 有了一个大概的了解,并从 代码层面 一起 分别基于 DGL和 Graph Learn 框架实现了 链接预测 , 节点分类与回归任务 。下面让我们开始 图上边(Edge) 的回归与分类任务 的介绍吧~

(1) 图知识升华理解

在 基于GCN和DGL实现的图上 node 分类, 值得一看!!! 文章中,我们了解到,无论在图上对 链接存在与否 进行 链接预测 还是对 节点进行定性性质 的 分类回归 任务,其实我们都是 基于某个节点与它周围邻居节点的关系 来进行判断的,我们 试图让模型去学习节点周围的局部特性以及图上空间结构上的全局相似或相异模式 ,以达到根据图上关系来完成 机器学习任务 的目的。


无论在图上进行什么性质的任务,图上的 消息传播 永远是分析的 核心点 。如何构图、彼此节点的属性之间与边的属性之间各种关系如何进行融合、如何初始化节点与边的属性值、消息传播的过程如何细化利用等,这些都需要经验与技巧,才能让我们的模型学习到我们任务所需要的东西,提升模型效果。


注意: 以前的历史文章中,对于分类回归任务作者并没有可以的去进行区分,因为至少在代码实现上没什么难度,仅仅是输入和损失的更换而已。而 重点去强调 了图上的有监督与无监督,链接还是节点的预测,以及本文所说的边的回归与分类等这些模块,是因为这些模块涉及到了一些 底层的采样思想 ,针对节点和边的不同操作以及图上消息传播的理解等,这些更有 价值 一些。


其实,经过前面一些列文章的讲解, 图算法的面纱 在我们面前已经一点点的被揭下来了,至少不再是很多人眼中那么的 高不可攀 了。随着对 图结构 的认识的深入,结合业务的运用,我们也应该能够了解到:其实现实中很多问题,天然适合用图算法来进行关系建模,而有很多问题,则并不适合用图来进行解决。而还有很多问题,虽然不适合直接用图来解决,但是可以用图任务做辅助任务,来产出有意义的中间结果 Embeding,为最终的目标任务服务。


实践才是检验真理的唯一标准,实践出真知 ,接下来让我们一起学习并且运用好图机器学习相关的知识吧~

(2) 图上边分类与回归理解

在历史文章 GraphSage与DGL实现同构图 Link 预测,通俗易懂好文强推 中,我们讲到 边分类与回归,就是去预测图上的边是属于哪种类型的边,以及去预测边上的属性值。例如:图上的边分类与回归预测的可以是购买了衣服还是搜藏的衣服,以及会买几件的这个数值,分别对应着 图上边的分类与回归 任务。


根据图的结构我们知道,边是链接着两个节点的并且边也可以携带属性特征 的。既然边可以携带属性特征,那么边也就可以导出 边的embeding 。中间我们要明确一点就是,因为边是起着 桥梁 的作用,所以边的embeding必然也是和边两端的节点有着密切关系的,我们可以在更新节点的embeding过程中,顺带着就把边的embeding更新了。当然,你要固定着初始值不训练也可以。


更直接一些: 如果用户使用上一节文章中的模型计算了节点的表示,那么用户只需要再编写一个用 apply_edges() 方法计算边预测的组件即可进行边分类/回归任务。


例如,对于本文中讲解的 边回归 任务,如果用户想为每条边计算一个分数,可对每一条边计算它的两端节点隐藏表示的点积来作为分数,也可以再接入几层DNN也是可以的。

(3) 代码时光

老规矩,开篇先吼一嗓子 , talk is cheap , show me the code !!!


经过 前面 几篇文章 的介绍,想来作者对 dgl来编写图上机器学习任务 已经非常清晰了。本章的代码相对简单,下面就让我们开始吧 ~


(3.1) 导包

注意: dgl 包选择 0.9 版本,只需要下面这些包就可以了。

@ 欢迎关注微信公众号:算法全栈之路
import torch
import torch.nn as nn
import dgl
import numpy as np
import dgl.function as fn
import dgl.nn.pytorch.conv as conv
import torch.nn as nn
import torch.nn.functional as F


(3.2) 构图与赋予初始特征

注意: 这里的数据我们可以使用 numpy 或则pandas 进行赋予初始值。如果要想embeding随着网络更新,可以采用 nn.Parameter 或则 Variable 变量的形式,并让模型的梯度更新优化算法可以包含更新这个变量即可。

@ 欢迎关注微信公众号:算法全栈之路
src = np.random.randint(0, 100, 500)
dst = np.random.randint(0, 100, 500)
# 同时建立反向边
edge_pred_graph = dgl.graph((np.concatenate([src, dst]), np.concatenate([dst, src])))
# 建立点和边特征,以及边的标签
edge_pred_graph.ndata['feature'] = torch.randn(100, 10)
edge_pred_graph.edata['feature'] = torch.randn(1000, 10)
edge_pred_graph.edata['label'] = torch.randn(1000)
# 进行训练、验证和测试集划分
edge_pred_graph.edata['train_mask'] = torch.zeros(1000, dtype=torch.bool).bernoulli(0.6)

这里,我们依然构建的是同构图任务。


注意:最后的mask部分,torch.zeros(1000, dtype=torch.bool).bernoulli(0.6) 使用的这个接口进行赋值,bernoulli(0.6) 必不可少,不然会报数据类型对不住的错,都是泪啊!!!


(3.3) 模型结构定义

使用 dgl 图深度学习框架 定义的,历史文章均有介绍,这里不再展开。

@ 欢迎关注微信公众号:算法全栈之路
class DotProductPredictor(nn.Module):
    def forward(self, graph, h):
        # h是GNN模型中计算出的节点表示
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            return graph.edata['score']
class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        # 实例化SAGEConve,in_feats是输入特征的维度,out_feats是输出特征的维度,aggregator_type是聚合函数的类型
        self.conv1 = conv.SAGEConv(
            in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')
        self.conv2 = conv.SAGEConv(
            in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')
    def forward(self, graph, inputs):
        # 输入是节点的特征
        h = self.conv1(graph, inputs)
        h = F.relu(h)
        h = self.conv2(graph, h)
        return h
class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.sage = SAGE(in_features, hidden_features, out_features)
        self.pred = DotProductPredictor()
    def forward(self, g, x):
        h = self.sage(g, x)
        return self.pred(g, h)


从这里的代码我们可以看到: 主要的函数依然是 Sage 算子 和 DotProductPredictor方法 ,这个我们在 GraphSage与DGL实现同构图 Link 预测,通俗易懂好文强推 有了介绍,这里就不再赘述了。


这里要 强调 的一点就是:图上边的 predict的值 是由 DotProductPredictor 这个方法点积计算得到的。我们也可以采用 几层DNN来融合两个节点的embeding , 就像这样 :

@ 欢迎关注微信公众号:算法全栈之路
class MLPPredictor(nn.Module):
    def __init__(self, in_features, out_classes):
        super().__init__()
        self.W = nn.Linear(in_features * 2, out_classes)
    def apply_edges(self, edges):
        h_u = edges.src['h']
        h_v = edges.dst['h']
        score = self.W(torch.cat([h_u, h_v], 1))
        # 这里是不是可以得到一个一定维度的边的embeding
        return {'score': score}
    def forward(self, graph, h):
        # h是从5.1节的GNN模型中计算出的节点表示
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(self.apply_edges)
            return graph.edata['score']

注意: 在上面的 apply_edges 方法里,我们是不是可以 得到一个一定维度的边的embeding呢,然后把它存入到边的属性特征中,随着网络更新而更新。那这样,我们是不是最后就可以到处图上边的embeding了呢。


(3.4) 模型训练

接下来,就是 模型训练了。注意,这里以边的回归算法为例进行说明,更改为边的分类算法也非常简单,读者可以下去自己修改下~

@ 欢迎关注微信公众号:算法全栈之路
# model train
# 在训练模型时可以使用布尔掩码区分训练、验证和测试数据集。该例子里省略了训练早停和模型保存部分的代码。
node_features = edge_pred_graph.ndata['feature']
edge_label = edge_pred_graph.edata['label']
train_mask = edge_pred_graph.edata['train_mask']
model = Model(10, 20, 5)
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
    pred = model(edge_pred_graph,node_features)
    # mse loss 损失
    loss = ((pred[train_mask] - edge_label[train_mask]) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

这里主要需要留意下 train mask和损失相关 的内容,代码通俗易懂,我就不再赘述了。


到这里,千字好文,基于未采样GraphSage算子和DGL实现的图上 Edge 回归 的全文就写完了。上面的代码demo 在环境没问题的情况下,全部复制到一个python文件里,就可以完美运行起来。


接下来会针对异构体写一篇复杂任务的文章,欢迎关注~


码字不易,觉得有收获就动动小手转载一下吧,你的支持是我写下去的最大动力 ~


更多更全更新内容: 算法全栈之路


相关文章
|
3天前
|
数据可视化 算法
R语言近似贝叶斯计算MCMC(ABC-MCMC)轨迹图和边缘图可视化
R语言近似贝叶斯计算MCMC(ABC-MCMC)轨迹图和边缘图可视化
|
资源调度 算法 计算机视觉
数字图像处理实验(六)|图像分割{阈值分割、直方图法、OTUS最大类间方差法(edge、im2dw、imfilter、imresize)、迭代阈值法、点检测}(附matlab实验代码和截图)
数字图像处理实验(六)|图像分割{阈值分割、直方图法、OTUS最大类间方差法(edge、im2dw、imfilter、imresize)、迭代阈值法、点检测}(附matlab实验代码和截图)
659 0
数字图像处理实验(六)|图像分割{阈值分割、直方图法、OTUS最大类间方差法(edge、im2dw、imfilter、imresize)、迭代阈值法、点检测}(附matlab实验代码和截图)
|
3天前
|
机器学习/深度学习 人工智能 数据可视化
【视频】R语言支持向量回归SVR预测水位实例讲解|附代码数据
【视频】R语言支持向量回归SVR预测水位实例讲解|附代码数据
|
3天前
|
机器学习/深度学习 数据可视化 PyTorch
基于TorchViz详解计算图(附代码)
基于TorchViz详解计算图(附代码)
91 0
|
3天前
|
机器学习/深度学习 Python
YOLOv8改进 | 进阶实战篇 | 利用YOLOv8进行过线统计(可用于人 、车过线统计)
YOLOv8改进 | 进阶实战篇 | 利用YOLOv8进行过线统计(可用于人 、车过线统计)
114 0
|
3天前
用图直观上理解梯度算子(一阶)与拉普拉斯算子(二阶)的区别,线检测与边缘检测的区别
用图直观上理解梯度算子(一阶)与拉普拉斯算子(二阶)的区别,线检测与边缘检测的区别
62 1
|
3天前
|
机器学习/深度学习 自动驾驶 安全
【论文速递】Arxiv2019 - MultiPath:行为预测的多重概率锚点轨迹假设
【论文速递】Arxiv2019 - MultiPath:行为预测的多重概率锚点轨迹假设
71 0
|
11月前
|
数据挖掘
R-apply| 基因表达量批量二分类,Get!(修正版)
R-apply| 基因表达量批量二分类,Get!(修正版)
|
11月前
|
数据挖掘 Serverless
Robust火山图:一种含离群值的代谢组数据差异分析方法
代谢组学中差异代谢物的识别仍然是一个巨大的挑战,并在代谢组学数据分析中发挥着突出的作用。由于分析、实验和生物的模糊性,代谢组学数据集经常包含异常值,但目前可用的差异代谢物识别技术对异常值很敏感。作者这里提出了一种基于权重的具有稳健性火山图方法,助于从含有离群值的代谢组数据中更加准确鉴定差异代谢物。
129 0
|
机器学习/深度学习 数据采集 算法
迁移学习「求解」偏微分方程,条件偏移下PDE的深度迁移算子学习(1)
迁移学习「求解」偏微分方程,条件偏移下PDE的深度迁移算子学习