图神经网络13-图注意力模型GAT网络详解

简介: 图神经网络13-图注意力模型GAT网络详解


论文摘要


图卷积发展至今,早期的进展可以归纳为谱图方法和非谱图方法,这两者都存在一些挑战性问题。

  • 谱图方法:学习滤波器主要基于图的拉普拉斯特征,图的拉普拉斯取决于图结构本身,因此在特定图结构上学习到的谱图模型无法直接应用到不同结构的图中。
  • 非谱图方法:对不同大小的邻域结构,像CNNs那样设计统一的卷积操作比较困难。

此外,图结构数据往往存在大量噪声,换句话说,节点之间的连接关系有时并没有特别重要,节点的不同邻居的相对重要性也有差异。


本文提出了图注意力网络(GAT),利用masked self-attention layer,通过堆叠网络层,获取每个节点的邻域特征,为邻域中的不同节点分配不同的权重。这样做的好处是不需要高成本的矩阵运算,也不用事先知道图结构信息。通过这种方式,GAT可以解决谱图方法存在的问题,同时也能应用于归纳学习和直推学习问题。


GAT模型结构


假设一个图有个节点,节点的维特征集合可以表示为

注意力层的目的是输出新的节点特征集合,

在这个过程中特征向量的维度可能会改变,即 为了保留足够的表达能力,将输入特征转化为高阶特征,至少需要一个可学习的线性变换。例如,对于节点,对它们的特征应用线性变换,从维转化为 维新特征为

上式在将输入特征运用线性变换转化为高阶特征后,使用self-attention为每个节点分配注意力(权重)。其中表示一个共享注意力机制:,用于计算注意力系数,也就是节点对节点的影响力系数(标量)。

上面的注意力计算考虑了图中任意两个节点,也就是说,图中每个节点对目标节点的影响都被考虑在内,这样就损失了图结构信息。论文中使用了masked attention,对于目标节点来说,只计算其邻域内的节点对目标节点的相关度(包括自身的影响)。

为了更好的在不同节点之间分配权重,我们需要将目标节点与所有邻居计算出来的相关度进行统一的归一化处理,这里用softmax归一化:

关于的选择,可以用向量的内积来定义一种无参形式的相关度计算,也可以定义成一种带参的神经网络层,只要满足,即输出一个标量值表示二者的相关度即可。在论文实验中,是一个单层前馈神经网络,参数为权重向量,使用负半轴斜率为0.2的LeakyReLU作为非线性激活函数:

其中表示拼接操作。完整的权重系数计算公式为:

得到归一化注意系数后,计算其对应特征的线性组合,通过非线性激活函数后,每个节点的最终输出特征向量为:


多头注意力机制


另外,本文使用多头注意力机制(multi-head attention)来稳定self-attention的学习过程,即对上式调用组相互独立的注意力机制,然后将输出结果拼接起来:

其中是拼接操作,是第组注意力机制计算出的权重系数,是对应的输入线性变换矩阵,最终输出的节点特征向量包含了个特征。为了减少输出的特征向量的维度,也可以将拼接操作替换为平均操作。

下面是的多头注意力机制示意图。不同颜色的箭头表示不同注意力的计算过程,每个邻居做三次注意力计算,每次attention计算就是一个普通的self-attention,输出一个,最后将三个不同的进行拼接或取平均,得到最终的


不同模型比较


  • GAT计算高效。self-attetion层可以在所有边上并行计算,输出特征可以在所有节点上并行计算;不需要特征分解或者其他内存耗费大的矩阵操作。单个head的GAT的时间复杂度为
  • 与GCN不同的是,GAT为同一邻域中的节点分配不同的重要性,提升了模型的性能。
  • 注意力机制以共享的方式应用于图中的所有边,因此它不依赖于对全局图结构的预先访问,也不依赖于对所有节点(特征)的预先访问(这是许多先前技术的限制)。
  • 不必要无向图。如果边不存在,可以忽略计算
  • 可以用于归纳学习;


评估


数据集

31.png


其中前三个引文网络用于直推学习,第四个蛋白质交互网络PPI用于归纳学习。


实验设置


  • 直推学习
  • 两层GAT模型,第一层多头注意力,输出特征维度(共64个特征),激活函数为指数线性单元(ELU);
  • 第二层单头注意力,计算个特征(为分类数),接softmax激活函数;
  • 为了处理小的训练集,模型中大量采用正则化方法,具体为L2正则化;
  • dropout;
  • 归纳学习:
  • 三层GAT模型,前两层多头注意力,输出特征维度(共1024个特征),激活函数为指数非线性单元(ELU);
  • 最后一层用于多标签分类,,每个头计算121个特征,后接logistic sigmoid激活函数;
  • 不使用正则化和dropout;
  • 使用了跨越中间注意力层的跳跃连接。
  • batch_size = 2 graph


实验结果


  • 不同数据集的分类准确率效果对比(Transductive)

    32.png


  • 数据集PPI上的F1效果(归纳学习)

33.png


  • 可视化

34.png


核心代码


GAT层代码:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(self.alpha)
    def forward(self, h, adj):
        Wh = torch.mm(h, self.W) # h.shape: (N, in_features), Wh.shape: (N, out_features)
        a_input = self._prepare_attentional_mechanism_input(Wh)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))
        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, Wh)
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime
    def _prepare_attentional_mechanism_input(self, Wh):
        N = Wh.size()[0] # number of nodes
        # Below, two matrices are created that contain embeddings in their rows in different orders.
        # (e stands for embedding)
        # These are the rows of the first matrix (Wh_repeated_in_chunks): 
        # e1, e1, ..., e1,            e2, e2, ..., e2,            ..., eN, eN, ..., eN
        # '-------------' -> N times  '-------------' -> N times       '-------------' -> N times
        # 
        # These are the rows of the second matrix (Wh_repeated_alternating): 
        # e1, e2, ..., eN, e1, e2, ..., eN, ..., e1, e2, ..., eN 
        # '----------------------------------------------------' -> N times
        # 
        Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0)
        Wh_repeated_alternating = Wh.repeat(N, 1)
        # Wh_repeated_in_chunks.shape == Wh_repeated_alternating.shape == (N * N, out_features)
        # The all_combination_matrix, created below, will look like this (|| denotes concatenation):
        # e1 || e1
        # e1 || e2
        # e1 || e3
        # ...
        # e1 || eN
        # e2 || e1
        # e2 || e2
        # e2 || e3
        # ...
        # e2 || eN
        # ...
        # eN || e1
        # eN || e2
        # eN || e3
        # ...
        # eN || eN
        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)
        # all_combinations_matrix.shape == (N * N, 2 * out_features)
        return all_combinations_matrix.view(N, N, 2 * self.out_features)
    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


GAT模型

import torch
import torch.nn as nn
import torch.nn.functional as F
from layers import GraphAttentionLayer, SpGraphAttentionLayer
class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        """Dense version of GAT."""
        super(GAT, self).__init__()
        self.dropout = dropout
        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)
        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)
    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))
        return F.log_softmax(x, dim=1)


参考文章


图神经网络:图注意力网络(GAT) https://jjzhou012.github.io/blog/2020/01/28/Graph-Attention-Networks.html


相关文章
|
3月前
|
C++
基于Reactor模型的高性能网络库之地址篇
这段代码定义了一个 InetAddress 类,是 C++ 网络编程中用于封装 IPv4 地址和端口的常见做法。该类的主要作用是方便地表示和操作一个网络地址(IP + 端口)
192 58
|
3月前
|
网络协议 算法 Java
基于Reactor模型的高性能网络库之Tcpserver组件-上层调度器
TcpServer 是一个用于管理 TCP 连接的类,包含成员变量如事件循环(EventLoop)、连接池(ConnectionMap)和回调函数等。其主要功能包括监听新连接、设置线程池、启动服务器及处理连接事件。通过 Acceptor 接收新连接,并使用轮询算法将连接分配给子事件循环(subloop)进行读写操作。调用链从 start() 开始,经由线程池启动和 Acceptor 监听,最终由 TcpConnection 管理具体连接的事件处理。
77 2
|
3月前
基于Reactor模型的高性能网络库之Tcpconnection组件
TcpConnection 由 subLoop 管理 connfd,负责处理具体连接。它封装了连接套接字,通过 Channel 监听可读、可写、关闭、错误等
97 1
|
3月前
|
JSON 监控 网络协议
干货分享“对接的 API 总是不稳定,网络分层模型” 看电商 API 故障的本质
本文从 OSI 七层网络模型出发,深入剖析电商 API 不稳定的根本原因,涵盖物理层到应用层的典型故障与解决方案,结合阿里、京东等大厂架构,详解如何构建高稳定性的电商 API 通信体系。
|
19天前
|
机器学习/深度学习 数据采集 人工智能
深度学习实战指南:从神经网络基础到模型优化的完整攻略
🌟 蒋星熠Jaxonic,AI探索者。深耕深度学习,从神经网络到Transformer,用代码践行智能革命。分享实战经验,助你构建CV、NLP模型,共赴二进制星辰大海。
|
20天前
|
机器学习/深度学习 传感器 算法
【无人车路径跟踪】基于神经网络的数据驱动迭代学习控制(ILC)算法,用于具有未知模型和重复任务的非线性单输入单输出(SISO)离散时间系统的无人车的路径跟踪(Matlab代码实现)
【无人车路径跟踪】基于神经网络的数据驱动迭代学习控制(ILC)算法,用于具有未知模型和重复任务的非线性单输入单输出(SISO)离散时间系统的无人车的路径跟踪(Matlab代码实现)
|
5月前
|
域名解析 网络协议 安全
计算机网络TCP/IP四层模型
本文介绍了TCP/IP模型的四层结构及其与OSI模型的对比。网络接口层负责物理网络接口,处理MAC地址和帧传输;网络层管理IP地址和路由选择,确保数据包准确送达;传输层提供端到端通信,支持可靠(TCP)或不可靠(UDP)传输;应用层直接面向用户,提供如HTTP、FTP等服务。此外,还详细描述了数据封装与解封装过程,以及两模型在层次划分上的差异。
831 13
|
5月前
|
网络协议 中间件 网络安全
计算机网络OSI七层模型
OSI模型分为七层,各层功能明确:物理层传输比特流,数据链路层负责帧传输,网络层处理数据包路由,传输层确保端到端可靠传输,会话层管理会话,表示层负责数据格式转换与加密,应用层提供网络服务。数据在传输中经过封装与解封装过程。OSI模型优点包括标准化、模块化和互操作性,但也存在复杂性高、效率较低及实用性不足的问题,在实际中TCP/IP模型更常用。
615 10
|
1月前
|
机器学习/深度学习 并行计算 算法
【CPOBP-NSWOA】基于豪冠猪优化BP神经网络模型的多目标鲸鱼寻优算法研究(Matlab代码实现)
【CPOBP-NSWOA】基于豪冠猪优化BP神经网络模型的多目标鲸鱼寻优算法研究(Matlab代码实现)
|
3月前
基于Reactor模型的高性能网络库之Poller(EpollPoller)组件
封装底层 I/O 多路复用机制(如 epoll)的抽象类 Poller,提供统一接口支持多种实现。Poller 是一个抽象基类,定义了 Channel 管理、事件收集等核心功能,并与 EventLoop 绑定。其子类 EPollPoller 实现了基于 epoll 的具体操作,包括事件等待、Channel 更新和删除等。通过工厂方法可创建默认的 Poller 实例,实现多态调用。
220 60