千字好文,基于未采样GraphSage算子和DGL实现的图上 Edge 回归

简介: 千字好文,基于未采样GraphSage算子和DGL实现的图上 Edge 回归

千字好文,基于未采样GraphSage算子和DGL实现的图上 Edge 回归

历史图算法相关文章:

(1)一文揭开图机器学习的面纱,你确定不来看看吗


(2) graphSage还是HAN ?吐血力作综述Graph Embeding 经典好文


(3) 看这里,使用docker部署图深度学习框架GraphLearn使用说明


(4) GraphSage与DGL实现同构图 Link 预测,通俗易懂好文强推


(5) 基于GCN和DGL实现的图上 node 分类, 值得一看!!!


书接上文,在前面的几篇文章中,我们对在 图上跑机器学习/深度学习模型 有了一个大概的了解,并从 代码层面 一起 分别基于 DGL和 Graph Learn 框架实现了 链接预测 , 节点分类与回归任务 。下面让我们开始 图上边(Edge) 的回归与分类任务 的介绍吧~

(1) 图知识升华理解

在 基于GCN和DGL实现的图上 node 分类, 值得一看!!! 文章中,我们了解到,无论在图上对 链接存在与否 进行 链接预测 还是对 节点进行定性性质 的 分类回归 任务,其实我们都是 基于某个节点与它周围邻居节点的关系 来进行判断的,我们 试图让模型去学习节点周围的局部特性以及图上空间结构上的全局相似或相异模式 ,以达到根据图上关系来完成 机器学习任务 的目的。


无论在图上进行什么性质的任务,图上的 消息传播 永远是分析的 核心点 。如何构图、彼此节点的属性之间与边的属性之间各种关系如何进行融合、如何初始化节点与边的属性值、消息传播的过程如何细化利用等,这些都需要经验与技巧,才能让我们的模型学习到我们任务所需要的东西,提升模型效果。


注意: 以前的历史文章中,对于分类回归任务作者并没有可以的去进行区分,因为至少在代码实现上没什么难度,仅仅是输入和损失的更换而已。而 重点去强调 了图上的有监督与无监督,链接还是节点的预测,以及本文所说的边的回归与分类等这些模块,是因为这些模块涉及到了一些 底层的采样思想 ,针对节点和边的不同操作以及图上消息传播的理解等,这些更有 价值 一些。


其实,经过前面一些列文章的讲解, 图算法的面纱 在我们面前已经一点点的被揭下来了,至少不再是很多人眼中那么的 高不可攀 了。随着对 图结构 的认识的深入,结合业务的运用,我们也应该能够了解到:其实现实中很多问题,天然适合用图算法来进行关系建模,而有很多问题,则并不适合用图来进行解决。而还有很多问题,虽然不适合直接用图来解决,但是可以用图任务做辅助任务,来产出有意义的中间结果 Embeding,为最终的目标任务服务。


实践才是检验真理的唯一标准,实践出真知 ,接下来让我们一起学习并且运用好图机器学习相关的知识吧~

(2) 图上边分类与回归理解

在历史文章 GraphSage与DGL实现同构图 Link 预测,通俗易懂好文强推 中,我们讲到 边分类与回归,就是去预测图上的边是属于哪种类型的边,以及去预测边上的属性值。例如:图上的边分类与回归预测的可以是购买了衣服还是搜藏的衣服,以及会买几件的这个数值,分别对应着 图上边的分类与回归 任务。


根据图的结构我们知道,边是链接着两个节点的并且边也可以携带属性特征 的。既然边可以携带属性特征,那么边也就可以导出 边的embeding 。中间我们要明确一点就是,因为边是起着 桥梁 的作用,所以边的embeding必然也是和边两端的节点有着密切关系的,我们可以在更新节点的embeding过程中,顺带着就把边的embeding更新了。当然,你要固定着初始值不训练也可以。


更直接一些: 如果用户使用上一节文章中的模型计算了节点的表示,那么用户只需要再编写一个用 apply_edges() 方法计算边预测的组件即可进行边分类/回归任务。


例如,对于本文中讲解的 边回归 任务,如果用户想为每条边计算一个分数,可对每一条边计算它的两端节点隐藏表示的点积来作为分数,也可以再接入几层DNN也是可以的。

(3) 代码时光

老规矩,开篇先吼一嗓子 , talk is cheap , show me the code !!!


经过 前面 几篇文章 的介绍,想来作者对 dgl来编写图上机器学习任务 已经非常清晰了。本章的代码相对简单,下面就让我们开始吧 ~


(3.1) 导包

注意: dgl 包选择 0.9 版本,只需要下面这些包就可以了。

@ 欢迎关注微信公众号:算法全栈之路
import torch
import torch.nn as nn
import dgl
import numpy as np
import dgl.function as fn
import dgl.nn.pytorch.conv as conv
import torch.nn as nn
import torch.nn.functional as F


(3.2) 构图与赋予初始特征

注意: 这里的数据我们可以使用 numpy 或则pandas 进行赋予初始值。如果要想embeding随着网络更新,可以采用 nn.Parameter 或则 Variable 变量的形式,并让模型的梯度更新优化算法可以包含更新这个变量即可。

@ 欢迎关注微信公众号:算法全栈之路
src = np.random.randint(0, 100, 500)
dst = np.random.randint(0, 100, 500)
# 同时建立反向边
edge_pred_graph = dgl.graph((np.concatenate([src, dst]), np.concatenate([dst, src])))
# 建立点和边特征,以及边的标签
edge_pred_graph.ndata['feature'] = torch.randn(100, 10)
edge_pred_graph.edata['feature'] = torch.randn(1000, 10)
edge_pred_graph.edata['label'] = torch.randn(1000)
# 进行训练、验证和测试集划分
edge_pred_graph.edata['train_mask'] = torch.zeros(1000, dtype=torch.bool).bernoulli(0.6)

这里,我们依然构建的是同构图任务。


注意:最后的mask部分,torch.zeros(1000, dtype=torch.bool).bernoulli(0.6) 使用的这个接口进行赋值,bernoulli(0.6) 必不可少,不然会报数据类型对不住的错,都是泪啊!!!


(3.3) 模型结构定义

使用 dgl 图深度学习框架 定义的,历史文章均有介绍,这里不再展开。

@ 欢迎关注微信公众号:算法全栈之路
class DotProductPredictor(nn.Module):
    def forward(self, graph, h):
        # h是GNN模型中计算出的节点表示
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            return graph.edata['score']
class SAGE(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        # 实例化SAGEConve,in_feats是输入特征的维度,out_feats是输出特征的维度,aggregator_type是聚合函数的类型
        self.conv1 = conv.SAGEConv(
            in_feats=in_feats, out_feats=hid_feats, aggregator_type='mean')
        self.conv2 = conv.SAGEConv(
            in_feats=hid_feats, out_feats=out_feats, aggregator_type='mean')
    def forward(self, graph, inputs):
        # 输入是节点的特征
        h = self.conv1(graph, inputs)
        h = F.relu(h)
        h = self.conv2(graph, h)
        return h
class Model(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.sage = SAGE(in_features, hidden_features, out_features)
        self.pred = DotProductPredictor()
    def forward(self, g, x):
        h = self.sage(g, x)
        return self.pred(g, h)


从这里的代码我们可以看到: 主要的函数依然是 Sage 算子 和 DotProductPredictor方法 ,这个我们在 GraphSage与DGL实现同构图 Link 预测,通俗易懂好文强推 有了介绍,这里就不再赘述了。


这里要 强调 的一点就是:图上边的 predict的值 是由 DotProductPredictor 这个方法点积计算得到的。我们也可以采用 几层DNN来融合两个节点的embeding , 就像这样 :

@ 欢迎关注微信公众号:算法全栈之路
class MLPPredictor(nn.Module):
    def __init__(self, in_features, out_classes):
        super().__init__()
        self.W = nn.Linear(in_features * 2, out_classes)
    def apply_edges(self, edges):
        h_u = edges.src['h']
        h_v = edges.dst['h']
        score = self.W(torch.cat([h_u, h_v], 1))
        # 这里是不是可以得到一个一定维度的边的embeding
        return {'score': score}
    def forward(self, graph, h):
        # h是从5.1节的GNN模型中计算出的节点表示
        with graph.local_scope():
            graph.ndata['h'] = h
            graph.apply_edges(self.apply_edges)
            return graph.edata['score']

注意: 在上面的 apply_edges 方法里,我们是不是可以 得到一个一定维度的边的embeding呢,然后把它存入到边的属性特征中,随着网络更新而更新。那这样,我们是不是最后就可以到处图上边的embeding了呢。


(3.4) 模型训练

接下来,就是 模型训练了。注意,这里以边的回归算法为例进行说明,更改为边的分类算法也非常简单,读者可以下去自己修改下~

@ 欢迎关注微信公众号:算法全栈之路
# model train
# 在训练模型时可以使用布尔掩码区分训练、验证和测试数据集。该例子里省略了训练早停和模型保存部分的代码。
node_features = edge_pred_graph.ndata['feature']
edge_label = edge_pred_graph.edata['label']
train_mask = edge_pred_graph.edata['train_mask']
model = Model(10, 20, 5)
opt = torch.optim.Adam(model.parameters())
for epoch in range(10):
    pred = model(edge_pred_graph,node_features)
    # mse loss 损失
    loss = ((pred[train_mask] - edge_label[train_mask]) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

这里主要需要留意下 train mask和损失相关 的内容,代码通俗易懂,我就不再赘述了。


到这里,千字好文,基于未采样GraphSage算子和DGL实现的图上 Edge 回归 的全文就写完了。上面的代码demo 在环境没问题的情况下,全部复制到一个python文件里,就可以完美运行起来。


接下来会针对异构体写一篇复杂任务的文章,欢迎关注~


码字不易,觉得有收获就动动小手转载一下吧,你的支持是我写下去的最大动力 ~


更多更全更新内容: 算法全栈之路


相关文章
|
9月前
|
SQL 安全 数据格式
PEP 750 t-string 深度解析:与 f-string 的差异与进化
Python 3.14 即将引入的 t-string(模板字符串)是字符串处理的重大革新。作为 f-string 的继任者,t-string 通过延迟渲染机制重新定义了字符串模板处理方式。本文从核心机制(即时求值 vs 延迟渲染)、技术特性(语法到语义进化)、应用场景(安全敏感场景、复杂模板系统等)及性能兼容性等方面深入解析,展示其在安全框架、代码生成等领域的广阔前景。开发者可根据需求选择 f-string 或 t-string,实现更高效、可控的字符串处理。
341 13
|
存储 关系型数据库 MySQL
浅谈Elasticsearch的入门与实践
本文主要围绕ES核心特性:分布式存储特性和分析检索能力,介绍了概念、原理与实践案例,希望让读者快速理解ES的核心特性与应用场景。
923 14
|
9月前
|
人工智能 自然语言处理 开发工具
HarmonyOS NEXT~鸿蒙开发能力:HarmonyOS SDK AI 全解析
本文深入解析HarmonyOS SDK中的AI功能集,涵盖分布式AI引擎、核心组件(NLP、计算机视觉等)及智能决策能力。通过代码示例与开发实践指南,帮助开发者掌握环境配置、性能调优及多场景应用(智能家居、移动办公等)。同时探讨性能优化策略与未来演进方向,助力构建高效分布式智能应用。
1049 9
|
存储 固态存储 Java
文件系统使用固态硬盘(SSD)
【10月更文挑战第8天】
551 2
|
前端开发
一键复制微信聊天框效果:HTML+CSS让网页聊天更生动!
一键复制微信聊天框效果:HTML+CSS让网页聊天更生动!
|
安全 网络安全 网络虚拟化
优化大型企业网络架构:从核心到边缘的全面升级
大型企业在业务运作中涉及多种数据传输,涵盖办公应用、CRM/ERP系统、数据中心、云环境、物联网及安全合规等多个方面。其复杂的业务生态和全球布局要求网络架构具备高效、安全和可靠的特性。网络设计需全面考虑核心层、汇聚层和接入层的功能与冗余,同时实现内外部的有效连接,包括广域网连接、远程访问策略、云计算集成及多层次安全防护,以构建高效且可扩展的网络生态系统。
优化大型企业网络架构:从核心到边缘的全面升级
|
机器学习/深度学习 API 计算机视觉
如何使用深度学习实现图像分类
深度学习在图像分类中扮演着核心角色,通过卷积神经网络(CNN)自动提取图像特征并分类。本文介绍深度学习原理及其实现流程,包括数据准备、构建CNN模型、训练与评估模型,并讨论如何在阿里云上部署模型及其实用场景。
|
安全 小程序 Java
基于Java医院门诊互联电子病历管理信息系统设计和实现(源码+LW+调试文档+讲解等)
基于Java医院门诊互联电子病历管理信息系统设计和实现(源码+LW+调试文档+讲解等)
H8
|
存储 传感器 机器学习/深度学习
数字孪生(Digital Twins)
数字映射(Digital twin),或译作数字孪生、数字分身,指在信息化平台内模拟物理实体、流程或者系统,类似实体系统在信息化平台中的双胞胎。借助于数字映射,可以在信息化平台上了解物理实体的状态,甚至可以对物理实体里面预定义的接口组件进行控制。
H8
1071 1
|
数据中心 云计算 网络架构

热门文章

最新文章