图神经网络13-图注意力模型GAT网络详解

简介: 图神经网络13-图注意力模型GAT网络详解


论文摘要


图卷积发展至今,早期的进展可以归纳为谱图方法和非谱图方法,这两者都存在一些挑战性问题。

  • 谱图方法:学习滤波器主要基于图的拉普拉斯特征,图的拉普拉斯取决于图结构本身,因此在特定图结构上学习到的谱图模型无法直接应用到不同结构的图中。
  • 非谱图方法:对不同大小的邻域结构,像CNNs那样设计统一的卷积操作比较困难。

此外,图结构数据往往存在大量噪声,换句话说,节点之间的连接关系有时并没有特别重要,节点的不同邻居的相对重要性也有差异。


本文提出了图注意力网络(GAT),利用masked self-attention layer,通过堆叠网络层,获取每个节点的邻域特征,为邻域中的不同节点分配不同的权重。这样做的好处是不需要高成本的矩阵运算,也不用事先知道图结构信息。通过这种方式,GAT可以解决谱图方法存在的问题,同时也能应用于归纳学习和直推学习问题。


GAT模型结构


假设一个图有个节点,节点的维特征集合可以表示为

注意力层的目的是输出新的节点特征集合,

在这个过程中特征向量的维度可能会改变,即 为了保留足够的表达能力,将输入特征转化为高阶特征,至少需要一个可学习的线性变换。例如,对于节点,对它们的特征应用线性变换,从维转化为 维新特征为

上式在将输入特征运用线性变换转化为高阶特征后,使用self-attention为每个节点分配注意力(权重)。其中表示一个共享注意力机制:,用于计算注意力系数,也就是节点对节点的影响力系数(标量)。

上面的注意力计算考虑了图中任意两个节点,也就是说,图中每个节点对目标节点的影响都被考虑在内,这样就损失了图结构信息。论文中使用了masked attention,对于目标节点来说,只计算其邻域内的节点对目标节点的相关度(包括自身的影响)。

为了更好的在不同节点之间分配权重,我们需要将目标节点与所有邻居计算出来的相关度进行统一的归一化处理,这里用softmax归一化:

关于的选择,可以用向量的内积来定义一种无参形式的相关度计算,也可以定义成一种带参的神经网络层,只要满足,即输出一个标量值表示二者的相关度即可。在论文实验中,是一个单层前馈神经网络,参数为权重向量,使用负半轴斜率为0.2的LeakyReLU作为非线性激活函数:

其中表示拼接操作。完整的权重系数计算公式为:

得到归一化注意系数后,计算其对应特征的线性组合,通过非线性激活函数后,每个节点的最终输出特征向量为:


多头注意力机制


另外,本文使用多头注意力机制(multi-head attention)来稳定self-attention的学习过程,即对上式调用组相互独立的注意力机制,然后将输出结果拼接起来:

其中是拼接操作,是第组注意力机制计算出的权重系数,是对应的输入线性变换矩阵,最终输出的节点特征向量包含了个特征。为了减少输出的特征向量的维度,也可以将拼接操作替换为平均操作。

下面是的多头注意力机制示意图。不同颜色的箭头表示不同注意力的计算过程,每个邻居做三次注意力计算,每次attention计算就是一个普通的self-attention,输出一个,最后将三个不同的进行拼接或取平均,得到最终的


不同模型比较


  • GAT计算高效。self-attetion层可以在所有边上并行计算,输出特征可以在所有节点上并行计算;不需要特征分解或者其他内存耗费大的矩阵操作。单个head的GAT的时间复杂度为
  • 与GCN不同的是,GAT为同一邻域中的节点分配不同的重要性,提升了模型的性能。
  • 注意力机制以共享的方式应用于图中的所有边,因此它不依赖于对全局图结构的预先访问,也不依赖于对所有节点(特征)的预先访问(这是许多先前技术的限制)。
  • 不必要无向图。如果边不存在,可以忽略计算
  • 可以用于归纳学习;


评估


数据集

31.png


其中前三个引文网络用于直推学习,第四个蛋白质交互网络PPI用于归纳学习。


实验设置


  • 直推学习
  • 两层GAT模型,第一层多头注意力,输出特征维度(共64个特征),激活函数为指数线性单元(ELU);
  • 第二层单头注意力,计算个特征(为分类数),接softmax激活函数;
  • 为了处理小的训练集,模型中大量采用正则化方法,具体为L2正则化;
  • dropout;
  • 归纳学习:
  • 三层GAT模型,前两层多头注意力,输出特征维度(共1024个特征),激活函数为指数非线性单元(ELU);
  • 最后一层用于多标签分类,,每个头计算121个特征,后接logistic sigmoid激活函数;
  • 不使用正则化和dropout;
  • 使用了跨越中间注意力层的跳跃连接。
  • batch_size = 2 graph


实验结果


  • 不同数据集的分类准确率效果对比(Transductive)

    32.png


  • 数据集PPI上的F1效果(归纳学习)

33.png


  • 可视化

34.png


核心代码


GAT层代码:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(self.alpha)
    def forward(self, h, adj):
        Wh = torch.mm(h, self.W) # h.shape: (N, in_features), Wh.shape: (N, out_features)
        a_input = self._prepare_attentional_mechanism_input(Wh)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))
        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, Wh)
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime
    def _prepare_attentional_mechanism_input(self, Wh):
        N = Wh.size()[0] # number of nodes
        # Below, two matrices are created that contain embeddings in their rows in different orders.
        # (e stands for embedding)
        # These are the rows of the first matrix (Wh_repeated_in_chunks): 
        # e1, e1, ..., e1,            e2, e2, ..., e2,            ..., eN, eN, ..., eN
        # '-------------' -> N times  '-------------' -> N times       '-------------' -> N times
        # 
        # These are the rows of the second matrix (Wh_repeated_alternating): 
        # e1, e2, ..., eN, e1, e2, ..., eN, ..., e1, e2, ..., eN 
        # '----------------------------------------------------' -> N times
        # 
        Wh_repeated_in_chunks = Wh.repeat_interleave(N, dim=0)
        Wh_repeated_alternating = Wh.repeat(N, 1)
        # Wh_repeated_in_chunks.shape == Wh_repeated_alternating.shape == (N * N, out_features)
        # The all_combination_matrix, created below, will look like this (|| denotes concatenation):
        # e1 || e1
        # e1 || e2
        # e1 || e3
        # ...
        # e1 || eN
        # e2 || e1
        # e2 || e2
        # e2 || e3
        # ...
        # e2 || eN
        # ...
        # eN || e1
        # eN || e2
        # eN || e3
        # ...
        # eN || eN
        all_combinations_matrix = torch.cat([Wh_repeated_in_chunks, Wh_repeated_alternating], dim=1)
        # all_combinations_matrix.shape == (N * N, 2 * out_features)
        return all_combinations_matrix.view(N, N, 2 * self.out_features)
    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


GAT模型

import torch
import torch.nn as nn
import torch.nn.functional as F
from layers import GraphAttentionLayer, SpGraphAttentionLayer
class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        """Dense version of GAT."""
        super(GAT, self).__init__()
        self.dropout = dropout
        self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)
        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)
    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.elu(self.out_att(x, adj))
        return F.log_softmax(x, dim=1)


参考文章


图神经网络:图注意力网络(GAT) https://jjzhou012.github.io/blog/2020/01/28/Graph-Attention-Networks.html


相关文章
|
14天前
|
机器学习/深度学习 人工智能 算法
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
鸟类识别系统。本系统采用Python作为主要开发语言,通过使用加利福利亚大学开源的200种鸟类图像作为数据集。使用TensorFlow搭建ResNet50卷积神经网络算法模型,然后进行模型的迭代训练,得到一个识别精度较高的模型,然后在保存为本地的H5格式文件。在使用Django开发Web网页端操作界面,实现用户上传一张鸟类图像,识别其名称。
60 12
鸟类识别系统Python+卷积神经网络算法+深度学习+人工智能+TensorFlow+ResNet50算法模型+图像识别
|
6天前
|
机器学习/深度学习 数据采集 网络安全
使用Python实现深度学习模型:智能网络安全威胁检测
使用Python实现深度学习模型:智能网络安全威胁检测
26 5
|
4天前
|
机器学习/深度学习 算法 搜索推荐
图神经网络综述:模型与应用
图神经网络综述:模型与应用
|
8天前
|
存储 机器人 Linux
Netty(二)-服务端网络编程常见网络IO模型讲解
Netty(二)-服务端网络编程常见网络IO模型讲解
|
4天前
|
存储 安全 网络安全
云计算与网络安全:技术融合下的信息安全新挑战
【9月更文挑战第29天】在数字化浪潮的推动下,云计算服务如雨后春笋般涌现,为各行各业提供了前所未有的便利和效率。然而,随着数据和服务的云端化,网络安全问题也日益凸显,成为制约云计算发展的关键因素之一。本文将从技术角度出发,探讨云计算环境下网络安全的重要性,分析云服务中存在的安全风险,并提出相应的防护措施。我们将通过实际案例,揭示如何在享受云计算带来的便捷的同时,确保数据的安全性和完整性。
|
4天前
|
SQL 安全 算法
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
【9月更文挑战第29天】随着互联网的普及,网络安全问题日益严重。本文将介绍网络安全漏洞、加密技术以及安全意识等方面的内容,帮助读者了解网络安全的重要性,提高自身的网络安全意识。
|
4天前
|
存储 SQL 安全
网络安全与信息安全:构建安全防线的关键策略
本文深入探讨了网络安全与信息安全领域的核心要素,包括网络安全漏洞、加密技术以及安全意识的重要性。通过对这些关键领域的分析,旨在为读者提供一套综合性的防护策略,帮助企业和个人在日益复杂的网络环境中保障数据安全。
15 4
|
3天前
|
SQL 安全 程序员
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
【9月更文挑战第30天】在数字化时代,网络安全和信息安全已成为全球关注的焦点。本文将探讨网络安全漏洞、加密技术以及提升安全意识的重要性。我们将通过代码示例,深入理解网络安全的基础知识,包括常见的网络攻击手段、防御策略和加密技术的实际应用。同时,我们还将讨论如何提高个人和企业的安全意识,以应对日益复杂的网络安全威胁。
|
2天前
|
SQL 安全 算法
数字时代的守护者:网络安全与信息安全的现代策略
【9月更文挑战第31天】在数字化时代,网络安全与信息安全成为保护个人隐私和企业资产的关键。本文将深入探讨网络安全漏洞的成因、加密技术的应用以及提升安全意识的重要性,旨在为读者提供防范网络威胁的策略和知识分享。
17 7
|
2天前
|
存储 安全 网络安全
网络安全与信息安全:关于网络安全漏洞、加密技术、安全意识等方面的知识分享
【9月更文挑战第31天】在数字化时代,网络安全和信息安全成为了我们生活中不可或缺的一部分。本文将从网络安全漏洞、加密技术和安全意识等方面进行知识分享,帮助读者更好地了解和保护自己的网络安全。
下一篇
无影云桌面