微软工程师用PyTorch实现图注意力网络,可视化效果惊艳

简介: 近日,一个关于图注意力网络可视化的项目吸引了大批研究人员的兴趣,上线仅仅一天,收获 200+ 星。该项目是关于用 PyTorch 实现的图注意力网络(GAT),包括易于理解的可视化。

微信图片_20211205005413.jpg


项目地址:https://github.com/gordicaleksa/pytorch-GAT


在正式介绍项目之前,先提一下图神经网络(GNN)。GNN 是一类基于深度学习的处理图域信息的方法。由于其较好的性能和可解释性,GNN 最近已成为一种广泛应用的图分析方法。现已广泛应用于计算生物学、计算药理学、推荐系统等。

GNN 把深度学习应用到图结构 (Graph) 中,其中的图卷积网络 GCN 可以在 Graph 上进行卷积操作,但是 GCN 存在一些缺陷。因此,Bengio 团队在三年前提出了图注意力网络(GAT,Graph Attention Network) ,旨在解决 GCN 存在的问题。


GAT 是空间(卷积)GNN 的代表。由于 CNNs 在计算机视觉领域取得了巨大的成功,研究人员决定将其推广到图形上,因此 GAT 应运而生。


现在,有人用 PyTorch 实现了 GAT 可视化。我们来看看该项目是如何实现的。


微信图片_20211205005434.jpg


可视化


Cora 可视化


说到 GNN,就不得不介绍一下 Cora 数据集。Cora 数据集由许多机器学习论文组成,是近年来图深度学习很喜欢使用的数据集。Cora 中的节点代表研究论文,链接是这些论文之间的引用。项目作者添加了一个用于可视化 Cora 和进行基本网络分析的实用程序。Cora 如下图所示:


微信图片_20211205005510.jpg

节点大小对应于其等级(即进出边的数量)。边的粗细大致对应于边的「popular」或「连接」程度。以下是显示 Cora 上等级(进出边的数量)分布的图:


微信图片_20211205005526.jpg

进和出的等级图是一样的,因为处理的是无向图。在底部的图(等级分布)上,我们可以看到一个有趣的峰值发生在 [2,4] 范围内。这意味着多数节点有少量的边,但是有 1 个节点有 169 条边(绿色大节点)。


注意力可视化


有了一个训练好的 GAT 模型以后,我们就可以将某些节点所学的注意力可视化。节点利用注意力来决定如何聚合周围的节点,如下图所示:


微信图片_20211205005559.jpg


这是 Cora 节点中边数最多的节点之一(引用)。颜色表示同一类的节点。


熵直方图


另一种理解 GAT 没有在 Cora 上学习注意力模式 (即它在学习常量注意力) 的方法是,将节点邻域的注意力权重视为概率分布,计算熵,并在每个节点邻域积累信息。


我们希望 GAT 的注意力分布有偏差。你可以看到橙色的直方图是理想均匀分布的样子,而浅蓝色的是学习后的分布,它们是完全一样的。


微信图片_20211205005626.jpg


分析 Cora 嵌入空间 (t-SNE)


GAT 的输出张量为 shape=(2708,7),其中 2708 是 Cora 中的节点数,7 是类数。用 t-SNE 把这些 7 维向量投影成 2D,得到:



微信图片_20211205005702.jpg


使用方法


方法 1:Jupyter Notebook


只需从 Anaconda 控制台运行 Jupyter Notebook,它将在你的默认浏览器中打开 session。打开 The Annotated GAT.ipynb 即可开始。

注意,如果你得到了 DLL load failed while importing win32api: The specified module could not be found,只需要 pip uninstall pywin32,或者 pip install pywin32、onda install pywin32。


方法 2:使用你选择的 IDE


如果使用自己选择的 IDE,只需要将 Python 环境和设置部分连接起来。


训练 GAT


在 Cora 上训练 GAT 所需的一切都已经设置好了,运行时只需调用 python training_script.py

此外,你还可以:

  • 添加 --should_visualize - 以可视化你的图形数据
  • 在数据的测试部分添加 --should_test - 以评估 GAT
  • 添加 --enable_tensorboard - 开始保存度量标准(准确率、损失)


代码部分的注释很完善,因此你可以了解到训练本身是如何运行的。

该脚本将:

  • 将 checkpoint* .pth 模型转储到 models/checkpoints/
  • 将 final* .pth 模型转储到 models/binaries/
  • 将度量标准保存到中 runs/,只需 tensorboard --logdir=runs 在 Anaconda 中运行即可将其可视化
  • 定期将一些训练元数据写入控制台


通过 tensorboard --logdir=runs 在控制台中调用,并将 http://localhost:6006/URL 粘贴到浏览器中,可以在训练过程中将度量标准可视化:


1638637096(1).png


可视化工具


如果要可视化 t-SNE 嵌入,请注意或嵌入该 visualize_gat_properties 函数的注释,并设置 visualization_type 为:

  • VisualizationType.ATTENTION - 如果希望可视化节点附近的注意力
  • VisualizationType.EMBEDDING - 如果希望可视化嵌入(通过 t-SNE)
  • VisualizationType.ENTROPY - 如果想可视化熵直方图


然后,你就得到了一张优秀的可视化效果图(VisualizationType.ATTENTION 可选):


微信图片_20211205005847.jpg

微信图片_20211205005849.jpg


硬件需求


GAT 不需要那种很强的硬件资源,尤其是如果你只想运行 Cora 的话,有 2GB 以上的 GPU 就可以了。

  • 在 RTX 2080 GPU 上训练 GAT 大约需要 10 秒;
  • 保留 1.5 GB 的 VRAM 内存(PyTorch 的缓存开销,为实际张量分配的内存少得多);
  • 模型本身只有 365 KB。


如果你想了解更多关于 GAT 的内容,请点击下方视频:


1638637167(1).png

点击查看原视频链接


相关文章
|
5月前
|
机器学习/深度学习 自然语言处理 PyTorch
Transformer自回归关键技术:掩码注意力原理与PyTorch完整实现
掩码注意力是生成模型的核心,通过上三角掩码限制模型仅关注当前及之前token,确保自回归因果性。相比BERT的双向注意力,它实现单向生成,是GPT等模型逐词预测的关键机制,核心仅需一步`masked_fill_`操作。
487 0
Transformer自回归关键技术:掩码注意力原理与PyTorch完整实现
|
机器学习/深度学习 计算机视觉
RT-DETR改进策略【Neck】| ASF-YOLO 注意力尺度序列融合模块改进颈部网络,提高小目标检测精度
RT-DETR改进策略【Neck】| ASF-YOLO 注意力尺度序列融合模块改进颈部网络,提高小目标检测精度
421 3
RT-DETR改进策略【Neck】| ASF-YOLO 注意力尺度序列融合模块改进颈部网络,提高小目标检测精度
|
5月前
|
机器学习/深度学习 算法 PyTorch
【Pytorch框架搭建神经网络】基于DQN算法、优先级采样的DQN算法、DQN + 人工势场的避障控制研究(Python代码实现)
【Pytorch框架搭建神经网络】基于DQN算法、优先级采样的DQN算法、DQN + 人工势场的避障控制研究(Python代码实现)
155 1
|
5月前
|
机器学习/深度学习 算法 PyTorch
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
【DQN实现避障控制】使用Pytorch框架搭建神经网络,基于DQN算法、优先级采样的DQN算法、DQN + 人工势场实现避障控制研究(Matlab、Python实现)
249 0
|
9月前
|
机器学习/深度学习 PyTorch 算法框架/工具
基于Pytorch 在昇腾上实现GCN图神经网络
本文详细讲解了如何在昇腾平台上使用PyTorch实现图神经网络(GCN)对Cora数据集进行分类训练。内容涵盖GCN背景、模型特点、网络架构剖析及实战分析。GCN通过聚合邻居节点信息实现“卷积”操作,适用于非欧氏结构数据。文章以两层GCN模型为例,结合Cora数据集(2708篇科学出版物,1433个特征,7种类别),展示了从数据加载到模型训练的完整流程。实验在NPU上运行,设置200个epoch,最终测试准确率达0.8040,内存占用约167M。
基于Pytorch 在昇腾上实现GCN图神经网络
|
9月前
|
机器学习/深度学习 算法 PyTorch
Perforated Backpropagation:神经网络优化的创新技术及PyTorch使用指南
深度学习近年来在多个领域取得了显著进展,但其核心组件——人工神经元和反向传播算法自提出以来鲜有根本性突破。穿孔反向传播(Perforated Backpropagation)技术通过引入“树突”机制,模仿生物神经元的计算能力,实现了对传统神经元的增强。该技术利用基于协方差的损失函数训练树突节点,使其能够识别神经元分类中的异常模式,从而提升整体网络性能。实验表明,该方法不仅可提高模型精度(如BERT模型准确率提升3%-17%),还能实现高效模型压缩(参数减少44%而无性能损失)。这一革新为深度学习的基础构建模块带来了新的可能性,尤其适用于边缘设备和大规模模型优化场景。
411 16
Perforated Backpropagation:神经网络优化的创新技术及PyTorch使用指南
|
9月前
|
机器学习/深度学习 搜索推荐 PyTorch
基于昇腾用PyTorch实现CTR模型DIN(Deep interest Netwok)网络
本文详细讲解了如何在昇腾平台上使用PyTorch训练推荐系统中的经典模型DIN(Deep Interest Network)。主要内容包括:DIN网络的创新点与架构剖析、Activation Unit和Attention模块的实现、Amazon-book数据集的介绍与预处理、模型训练过程定义及性能评估。通过实战演示,利用Amazon-book数据集训练DIN模型,最终评估其点击率预测性能。文中还提供了代码示例,帮助读者更好地理解每个步骤的实现细节。
|
9月前
|
机器学习/深度学习 自然语言处理 PyTorch
基于Pytorch Gemotric在昇腾上实现GAT图神经网络
本实验基于昇腾平台,使用PyTorch实现图神经网络GAT(Graph Attention Networks)在Pubmed数据集上的分类任务。内容涵盖GAT网络的创新点分析、图注意力机制原理、多头注意力机制详解以及模型代码实战。实验通过两层GAT网络对Pubmed数据集进行训练,验证模型性能,并展示NPU上的内存使用情况。最终,模型在测试集上达到约36.60%的准确率。
|
9月前
|
算法 PyTorch 算法框架/工具
PyTorch 实现FCN网络用于图像语义分割
本文详细讲解了在昇腾平台上使用PyTorch实现FCN(Fully Convolutional Networks)网络在VOC2012数据集上的训练过程。内容涵盖FCN的创新点分析、网络架构解析、代码实现以及端到端训练流程。重点包括全卷积结构替换全连接层、多尺度特征融合、跳跃连接和反卷积操作等技术细节。通过定义VOCSegDataset类处理数据集,构建FCN8s模型并完成训练与测试。实验结果展示了模型在图像分割任务中的应用效果,同时提供了内存使用优化的参考。
|
9月前
|
机器学习/深度学习 算法 PyTorch
基于Pytorch Gemotric在昇腾上实现GraphSage图神经网络
本实验基于PyTorch Geometric,在昇腾平台上实现GraphSAGE图神经网络,使用CiteSeer数据集进行分类训练。内容涵盖GraphSAGE的创新点、算法原理、网络架构及实战分析。GraphSAGE通过采样和聚合节点邻居特征,支持归纳式学习,适用于未见节点的表征生成。实验包括模型搭建、训练与验证,并在NPU上运行,最终测试准确率达0.665。

热门文章

最新文章

推荐镜像

更多