【图神经网络DGL】消息传递范式

简介: 【图神经网络DGL】消息传递范式消息函数:接受一个参数 edges,这是一个 dgl.EdgeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批边。edges有三个成员属性:src、dst和data,分别用于访问源节点、目标节点和边的特征。

一、消息传递范式

聚合函数和更新函数。

1111.png

1.1 内置函数和消息传递API

(1)API属性介绍

消息函数:接受一个参数 edges,这是一个 dgl.EdgeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批边。edges有三个成员属性:src、dst和data,分别用于访问源节点、目标节点和边的特征。

聚合函数:接受一个参数 nodes,这是一个 NodeBatch 的实例, 在消息传递时,它被DGL在内部生成以表示一批节点。 nodes 的成员属性 mailbox 可以用来访问节点收到的消息。 一些最常见的聚合操作包括 sum、max、min 、mean等。聚合函数一般有2个参数,它们的类型都是字符串:

一个用于指定mailbox中的字段名;

一个用于指示目标节点特征的字段名,例如dgl.function.sum('m', 'h')等价于如下所示的对接收到消息求和的用户定义函数:

import torch
def reduce_func(nodes):
     return {'h': torch.sum(nodes.mailbox['m'], dim=1)}

更新函数:接受一个如上所述的参数 nodes。此函数对 聚合函数 的聚合结果进行操作, 通常在消息传递的最后一步将其与节点的特征相结合,并将输出作为节点的新特征。

(2)内置&自定义函数

(1)命名空间dfl.function中实现了常用的(内置的)消息函数和聚合函数(能够自动处理维度广播),当然也可以自定义函数。

(2)(自定义)内置消息函数:可以是一元函数(dgl支持copy函数),也支持二元函数(dgl支持add、sub、mul、div、dot函数):

消息的内置函数的命名约定是u表示源节点,v表示目标节点,e表示边。

这些函数的参数是字符串,指示相应节点和边的输入和输出特征字段名。

ex:要对源节点的hu特征和目标节点的hv特征求和,然后将结果保存在边的he特征上:dgl.function.u_add_v('hu', 'hv', 'he');

  • 如下自定义消息函数和内置函数相同:
def message_func(edges):
     return {'he': edges.src['hu'] + edges.dst['hv']}

(3)在DGL中,也可以在不涉及消息传递的情况下,通过 apply_edges() 单独调用逐边计算。

  • apply_edges() 的参数是一个消息函数。
  • 在默认情况下,这个接口将更新所有的边。
  • 例如:
import dgl.function as fn
graph.apply_edges(fn.u_add_v('el', 'er', 'e'))

(4)消息传递的高级APIupdate_all()

它在单个API调用里合并了消息生成、消息聚合和节点特征更新,为从整体上进行系统优化提供了空间。

update_all()的参数:一个消息、聚合、更新函数(可选,也可以在外面操作,dgl不推荐在update_all中指定更新函数)。

更新函数是一个可选择的参数,可以不使用,而是在update_all执行完后直接对节点特征进行操作;

因为更新函数通常可用纯张量操作实现,所以DGL不推荐在update_all中指定更新函数,如函数:

image.png

def updata_all_example(graph):
    # 在graph.ndata['ft']中存储结果
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))
    # 在update_all外调用更新函数
    final_ft = graph.ndata['ft'] * 2
    return final_ft

ex:graph.update_all(fn.u_mul_e('ft', 'a', 'm')将源节点特征tf和边特征a相乘生成消息m,fn.sum('m', 'ft')再对所有消息求和来更新节点特征ft,再乘2后得到最终结果final_ft。调用后,中间消息m会被清除。

1.2 编写高效的消息传递代码

关于dgl内置函数是如何优化消息传递的内存消耗和计算速度的, 详见文字描述: DGL官方文档 ; 总结来说主要是合并内核, 并行逐边运算, 减少点边拷贝等; 如update_all()函数就是一个效率很高的接口; 如果确实需要使用apply_edges()函数在边上保存消息, 则内存占用会非常大;

(1)一个通过对节点特征降维来减少消息维度的示例:

拼接源节点与目标节点特征, 然后应用一个线性层: W × ( u ∣ ∣ v ) W\times (u||v)W×(u∣∣v);

这样源节点与目标节点特征维数较高, 而线性层输出维数较低;

代码示例:

import torch
import torch.nn as nn
linear = nn.Parameter(torch.FloatTensor(size=(1, node_feat_dim * 2)))
def concat_message_function(edges):
   return {'cat_feat': torch.cat([edges.src.ndata['feat'], edges.dst.ndata['feat']])}
g.apply_edges(concat_message_function)
g.edata['out'] = g.edata['cat_feat'] * linear 

image.png

import dgl.function as fn
linear_src = nn.Parameter(torch.FloatTensor(size=(1, node_feat_dim)))
linear_dst = nn.Parameter(torch.FloatTensor(size=(1, node_feat_dim)))
out_src = g.ndata['feat'] * linear_src
out_dst = g.ndata['feat'] * linear_dst
g.srcdata.update({'out_src': out_src})
g.dstdata.update({'out_dst': out_dst})
g.apply_edges(fn.u_add_v('out_src', 'out_dst', 'out'))

这两种方法数学上等价, 但后一种方法更加高效, 因为无需再边上保存feat_src和feat_dst, 空间占用小, 另外加法可以直接用内置函数u_add_v进行优化, 内置函数的效率一般比自定义函数要高。

1.3 在图的一部分上进行消息传递

如果用户只想更新图中部分节点,先将想处理的节点编号创建一个子图,然后对其调用update_all()(这也是小批量处理中的常见用法)。

nid = [0, 2, 3, 6, 7, 9]
sg = g.subgraph(nid)
sg.update_all(message_func, reduce_func, apply_node_func)

1.4 在消息传递中使用边的权重

常见的GNN建模做法:在消息聚合前使用边的权重,如GAT和一些GCN的变种。dgl的处理:

将权重存为边的特征;

在消息函数中用边的特征和源节点的特征相乘。

ex:假定下面的权重eweight是一个形状为(E, *)的张量,E是边的数量。权重存为边的特征,即eweight被用作边的权重(通常是一个标量)。

import dgl.function as fn
# 假定eweight是一个形状为(E, *)的张量,E是边的数量。
graph.edata['a'] = eweight
graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                 fn.sum('m', 'ft'))

1.5 在异构图上进行消息传递

本质上异构图的消息传递与同构图并没有太大区别,异构图上的消息传递可以分为两个部分:

对每个关系计算和聚合消息。

对每个结点聚合来自不同关系的消息。

DGLGraph.multi_update_all(etype_dict, cross_reducer, apply_node_func=None):

etype_dict: dict类型, 键为一种关系, 值为这种关系对应的update_all()的参数;

cross_reducer: str类型, 表示跨类型整合函数, 来指定整合不同关系聚合结果的方式, 可以是sum, min, max, mean, stack中之一;

在DGL中,对异构图进行消息传递的接口是 multi_update_all()。 multi_update_all() 接受一个字典。这个字典的每一个键值对里,键是一种关系, 值是这种关系对应 update_all() 的参数。 multi_update_all() 还接受一个字符串来表示跨类型整合函数,来指定整合不同关系聚合结果的方式。 这个整合方式可以是 sum、 min、 max、 mean 和 stack 中的一个。

import dgl.function as fn
for c_etype in G.canonical_etypes:
    srctype, etype, dsttype = c_etype
    Wh = self.weight[etype](feat_dict[srctype])
    # 把它存在图中用来做消息传递
    G.nodes[srctype].data['Wh_%s' % etype] = Wh
    # 指定每个关系的消息传递函数:(message_func, reduce_func).
    # 注意结果保存在同一个目标特征“h”,说明聚合是逐类进行的。
    funcs[etype] = (fn.copy_u('Wh_%s' % etype, 'm'), fn.mean('m', 'h'))
# 将每个类型消息聚合的结果相加。
G.multi_update_all(funcs, 'sum')
# 返回更新过的节点特征字典
return {ntype : G.nodes[ntype].data['h'] for ntype in G.ntypes}

小结消息传递的流程:

image.png

消息函数(message function):传递消息的目的是将节点计算时需要的信息传递给它,因此对每条边来说,每个源节点将会将自身的Embedding(e.src.data)和边的Embedding(edge.data)传递到目的节点;对于每个目的节点来说,它可能会受到多个源节点传过来的消息,它会将这些消息存储在"邮箱"中。

汇聚函数(reduce function):汇聚函数的目的是根据邻居传过来的消息更新跟新自身节点Embedding,对每个节点来说,它先从邮箱(v.mailbox[‘m’])中汇聚消息函数所传递过来的消息(message),并清空邮箱(v.mailbox[‘m’])内消息;然后该节点结合汇聚后的结果和该节点原Embedding,更新节点Embedding。

更新: message函数的参数是边,包括源节点,目标节点的特征信息,处理完的数据放置到节点的mailbox中。 聚合函数reduce_function或 apply_node函数作用于节点本身,即传入的参数是节点信息和节点的邮箱信息。

二、不带边权重的例子

2.1 消息传递框架

DGL遵循Gilmer等人提出的消息8传递框架,很多GNN模型能符合如下框架:

image.png

2.2 GraphSAGE的消息传递

如GraphSAGE可表示为:

image.png

我们可以看到消息传递是定向(有方向)的:从一个节点u发送到另一个节点v的消息不一定与从节点v发送到相反方向的节点u的消息相同。

DGL提供了GraphSAGE的实现dgl.nn.SAGEConv

import dgl.function as fn
class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.
    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """
    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)
    def forward(self, g, h):
        """Forward computation
        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        """
        with g.local_scope():
            g.ndata['h'] = h
            # update_all is a message passing API.
            g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N'))
            h_N = g.ndata['h_N']
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)

上述代码中的核心部分是g.update_all函数,该函数收集并平均相邻特征。这里有三个概念:

消息函数fn.copy_u('h','m'),它将名为“h”的节点特征复制为发送给邻居的消息

聚合函数fn.mean('m', 'h_N'),该函数对所有接收到的消息中名为’m’的信息进行平均,并将结果保存为新的节点特征’h_N’

update_all让DGL触发所有节点和边的消息函数和聚合函数

2.3 堆叠网络

然后我们可以堆叠自己的GraphSAGE卷积层以构成多层GraphSAGE网络:

import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = GraphSAGE(in_feats, h_feats)
        self.conv2 = GraphSAGE(h_feats, num_classes)
    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h

2.4 训练网络

import dgl.data
dataset = dgl.data.CoraGraphDataset()
g = dataset[0]
def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    all_logits = []
    best_val_acc = 0
    best_test_acc = 0
    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    for e in range(200):
        # Forward
        logits = model(g, features)
        # Compute prediction
        pred = logits.argmax(1)
        # Compute loss
        # Note that we should only compute the losses of the nodes in the training set,
        # i.e. with train_mask 1.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])
        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()
        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc
        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        all_logits.append(logits.detach())
        if e % 5 == 0:
            print('In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})'.format(
                e, loss, val_acc, best_val_acc, test_acc, best_test_acc))
model = Model(g.ndata['feat'].shape[1], 16, dataset.num_classes)
train(g, model)

2.5 结果

Using backend: pytorch
  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
In epoch 0, loss: 1.951, val acc: 0.114 (best 0.114), test acc: 0.103 (best 0.103)
In epoch 5, loss: 1.900, val acc: 0.290 (best 0.292), test acc: 0.278 (best 0.277)
In epoch 10, loss: 1.790, val acc: 0.462 (best 0.462), test acc: 0.435 (best 0.435)
In epoch 15, loss: 1.614, val acc: 0.502 (best 0.502), test acc: 0.489 (best 0.489)
In epoch 20, loss: 1.372, val acc: 0.548 (best 0.548), test acc: 0.529 (best 0.529)
In epoch 25, loss: 1.087, val acc: 0.592 (best 0.592), test acc: 0.591 (best 0.591)
In epoch 30, loss: 0.798, val acc: 0.650 (best 0.650), test acc: 0.639 (best 0.639)
In epoch 35, loss: 0.547, val acc: 0.690 (best 0.690), test acc: 0.682 (best 0.682)
In epoch 40, loss: 0.358, val acc: 0.710 (best 0.710), test acc: 0.721 (best 0.721)
In epoch 45, loss: 0.230, val acc: 0.736 (best 0.736), test acc: 0.734 (best 0.734)
In epoch 50, loss: 0.149, val acc: 0.738 (best 0.738), test acc: 0.743 (best 0.744)
In epoch 55, loss: 0.099, val acc: 0.740 (best 0.740), test acc: 0.744 (best 0.743)
In epoch 60, loss: 0.068, val acc: 0.742 (best 0.742), test acc: 0.743 (best 0.745)
In epoch 65, loss: 0.048, val acc: 0.734 (best 0.742), test acc: 0.749 (best 0.745)
In epoch 70, loss: 0.036, val acc: 0.736 (best 0.742), test acc: 0.753 (best 0.745)
In epoch 75, loss: 0.028, val acc: 0.734 (best 0.742), test acc: 0.755 (best 0.745)
In epoch 80, loss: 0.023, val acc: 0.738 (best 0.742), test acc: 0.757 (best 0.745)
In epoch 85, loss: 0.019, val acc: 0.738 (best 0.742), test acc: 0.758 (best 0.745)
In epoch 90, loss: 0.017, val acc: 0.742 (best 0.742), test acc: 0.756 (best 0.745)
In epoch 95, loss: 0.015, val acc: 0.742 (best 0.742), test acc: 0.755 (best 0.745)
In epoch 100, loss: 0.013, val acc: 0.742 (best 0.742), test acc: 0.755 (best 0.745)
In epoch 105, loss: 0.012, val acc: 0.742 (best 0.742), test acc: 0.755 (best 0.745)
In epoch 110, loss: 0.011, val acc: 0.742 (best 0.742), test acc: 0.753 (best 0.745)
In epoch 115, loss: 0.010, val acc: 0.742 (best 0.742), test acc: 0.753 (best 0.745)
In epoch 120, loss: 0.009, val acc: 0.742 (best 0.742), test acc: 0.754 (best 0.745)
In epoch 125, loss: 0.008, val acc: 0.742 (best 0.742), test acc: 0.754 (best 0.745)
In epoch 130, loss: 0.008, val acc: 0.742 (best 0.742), test acc: 0.752 (best 0.745)
In epoch 135, loss: 0.007, val acc: 0.742 (best 0.742), test acc: 0.752 (best 0.745)
In epoch 140, loss: 0.007, val acc: 0.744 (best 0.744), test acc: 0.751 (best 0.751)
In epoch 145, loss: 0.006, val acc: 0.744 (best 0.744), test acc: 0.751 (best 0.751)
In epoch 150, loss: 0.006, val acc: 0.744 (best 0.744), test acc: 0.749 (best 0.751)
In epoch 155, loss: 0.006, val acc: 0.744 (best 0.744), test acc: 0.750 (best 0.751)
In epoch 160, loss: 0.005, val acc: 0.742 (best 0.744), test acc: 0.750 (best 0.751)
In epoch 165, loss: 0.005, val acc: 0.742 (best 0.744), test acc: 0.753 (best 0.751)
In epoch 170, loss: 0.005, val acc: 0.742 (best 0.744), test acc: 0.753 (best 0.751)
In epoch 175, loss: 0.005, val acc: 0.742 (best 0.744), test acc: 0.752 (best 0.751)
In epoch 180, loss: 0.004, val acc: 0.742 (best 0.744), test acc: 0.753 (best 0.751)
In epoch 185, loss: 0.004, val acc: 0.742 (best 0.744), test acc: 0.753 (best 0.751)
In epoch 190, loss: 0.004, val acc: 0.742 (best 0.744), test acc: 0.754 (best 0.751)
In epoch 195, loss: 0.004, val acc: 0.742 (best 0.744), test acc: 0.754 (best 0.751)

三、带边权重的例子

这里需要改的两个位置,g.update_all的2个参数,以及在Model中要设置边权重的传参。其他就没啥变化了。即使用带权平均聚合邻居表示,edata成员可以保存边权重(特征),这些特征也可以参与消息传递。

# data可以包含边特征信息,同时传递
class WeightedSAGEConv(nn.Module):
    """
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """
    def __init__(self, in_feat, out_feat):
        super(WeightedSAGEConv, self).__init__()
        # 将input和邻近节点特征映射到outpu线性子模块
        self.linear = nn.Linear(in_feat * 2, out_feat)
    def forward(self, g, h, w):
        """
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        w : Tensor
            The edge weight.
        """
        with g.local_scope():
            g.ndata['h'] = h
            # 加入边的权重,进行消息传递和更新
            g.edata['w'] = w
            g.update_all(message_func=fn.u_mul_e('h', 'w', 'm'), reduce_func=fn.mean('m', 'h_N'))
            h_N = g.ndata['h_N']
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)
# 因为这个数据集中的图没有边的权值,
# 所以我们在模型的 forward 函数中手动将所有边的权值赋给1。
class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = WeightedSAGEConv(in_feats, h_feats)
        self.conv2 = WeightedSAGEConv(h_feats, num_classes)
    def forward(self, g, in_feat):
        # 3个参数,(g, h, w)即图,点特征,边权重
        h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device))
        h = F.relu(h)
        # 设置所有边的权重为1
        h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device))
        return h

四、DGL按照优先级高低排序推荐做法

直接调用dgl.nn模块;

使用dgl.nn.functional内置方法,适合一些简单操作,如为每个节点计算softmax;

使用update_all,内置的消息函数和聚合函数;

使用用户自定义的消息(message)函数和聚合(reduce)函数。

五、用户自定义函数

DGL允许用户自定义消息函数和聚合函数以获得最大的表达能力。以下是一个用户定义的消息函数,它等价于fn.u_mul_e('h', 'w', 'm')。

def u_mul_e_udf(edges):
    return {"m": edges.src["h"] * edges.data["w"]}

参数edges共有三个成员:src,data和dst,分别代表所有边的源节点特征,边特征和目标节点特征。

也可以编写自己的聚合函数。例如,下面的函数相当于内置的fn.sum(‘m’, ‘h’)函数,它对传入的消息求和:

def sum_udf(nodes):
    return {"h": nodes.mailbox["m"].sum(dim=1)} 
    # dim=1,按行求和

总之,DGL将按节点的度数对节点进行分组,对于每个组DGL将传入的消息沿着第2维度(按行)进行堆叠,然后沿第2个维度执行缩减(reduce)以聚合消息。

相关文章
|
30天前
|
机器学习/深度学习 数据可视化
KAN干翻MLP,开创神经网络新范式!一个数十年前数学定理,竟被MIT华人学者复活了
【10月更文挑战第12天】MIT华人学者提出了一种基于Kolmogorov-Arnold表示定理的新型神经网络——KAN。与传统MLP不同,KAN将可学习的激活函数放在权重上,使其在表达能力、准确性、可解释性和收敛速度方面表现出显著优势,尤其在处理高维数据时效果更佳。然而,KAN的复杂性也可能带来部署和维护的挑战。论文地址:https://arxiv.org/pdf/2404.19756
40 1
|
5月前
|
机器学习/深度学习 人工智能 自然语言处理
Transformer 能代替图神经网络吗?
Transformer模型的革新性在于其自注意力机制,广泛应用于多种任务,包括非原始设计领域。近期研究专注于Transformer的推理能力,特别是在图神经网络(GNN)上下文中。
96 5
|
4月前
|
机器学习/深度学习 搜索推荐 知识图谱
图神经网络加持,突破传统推荐系统局限!北大港大联合提出SelfGNN:有效降低信息过载与数据噪声影响
【7月更文挑战第22天】北大港大联手打造SelfGNN,一种结合图神经网络与自监督学习的推荐系统,专攻信息过载及数据噪声难题。SelfGNN通过短期图捕获实时用户兴趣,利用自增强学习提升模型鲁棒性,实现多时间尺度动态行为建模,大幅优化推荐准确度与时效性。经四大真实数据集测试,SelfGNN在准确性和抗噪能力上超越现有模型。尽管如此,高计算复杂度及对图构建质量的依赖仍是待克服挑战。[详细论文](https://arxiv.org/abs/2405.20878)。
80 5
|
4月前
|
机器学习/深度学习 PyTorch 算法框架/工具
图神经网络是一类用于处理图结构数据的神经网络。与传统的深度学习模型(如卷积神经网络CNN和循环神经网络RNN)不同,
图神经网络是一类用于处理图结构数据的神经网络。与传统的深度学习模型(如卷积神经网络CNN和循环神经网络RNN)不同,
|
4月前
|
机器学习/深度学习 编解码 数据可视化
图神经网络版本的Kolmogorov Arnold(KAN)代码实现和效果对比
目前我们看到有很多使用KAN替代MLP的实验,但是目前来说对于图神经网络来说还没有类似的实验,今天我们就来使用KAN创建一个图神经网络Graph Kolmogorov Arnold(GKAN),来测试下KAN是否可以在图神经网络方面有所作为。
187 0
|
5月前
|
机器学习/深度学习 数据采集 TensorFlow
使用Python实现深度学习模型:图神经网络(GNN)
使用Python实现深度学习模型:图神经网络(GNN)
264 1
|
6月前
|
机器学习/深度学习 自然语言处理 搜索推荐
【传知代码】图神经网络长对话理解-论文复现
在ACL2023会议上发表的论文《使用带有辅助跨模态交互的关系时态图神经网络进行对话理解》提出了一种新方法,名为correct,用于多模态情感识别。correct框架通过全局和局部上下文信息捕捉对话情感,同时有效处理跨模态交互和时间依赖。模型利用图神经网络结构,通过构建图来表示对话中的交互和时间关系,提高了情感预测的准确性。在IEMOCAP和CMU-MOSEI数据集上的实验结果证明了correct的有效性。源码和更多细节可在文章链接提供的附件中获取。
【传知代码】图神经网络长对话理解-论文复现
|
5月前
|
机器学习/深度学习 搜索推荐 PyTorch
【机器学习】图神经网络:深度解析图神经网络的基本构成和原理以及关键技术
【机器学习】图神经网络:深度解析图神经网络的基本构成和原理以及关键技术
1135 2
|
6月前
|
机器学习/深度学习 JSON PyTorch
图神经网络入门示例:使用PyTorch Geometric 进行节点分类
本文介绍了如何使用PyTorch处理同构图数据进行节点分类。首先,数据集来自Facebook Large Page-Page Network,包含22,470个页面,分为四类,具有不同大小的特征向量。为训练神经网络,需创建PyTorch Data对象,涉及读取CSV和JSON文件,处理不一致的特征向量大小并进行归一化。接着,加载边数据以构建图。通过`Data`对象创建同构图,之后数据被分为70%训练集和30%测试集。训练了两种模型:MLP和GCN。GCN在测试集上实现了80%的准确率,优于MLP的46%,展示了利用图信息的优势。
90 1
|
6月前
|
机器学习/深度学习 数据挖掘 算法框架/工具
想要了解图或图神经网络?没有比看论文更好的方式,面试阿里国际站运营一般会问什么
想要了解图或图神经网络?没有比看论文更好的方式,面试阿里国际站运营一般会问什么