微软工程师用PyTorch实现图注意力网络,可视化效果惊艳

简介: 近日,一个关于图注意力网络可视化的项目吸引了大批研究人员的兴趣,上线仅仅一天,收获 200+ 星。该项目是关于用 PyTorch 实现的图注意力网络(GAT),包括易于理解的可视化。

微信图片_20211205005413.jpg


项目地址:https://github.com/gordicaleksa/pytorch-GAT


在正式介绍项目之前,先提一下图神经网络(GNN)。GNN 是一类基于深度学习的处理图域信息的方法。由于其较好的性能和可解释性,GNN 最近已成为一种广泛应用的图分析方法。现已广泛应用于计算生物学、计算药理学、推荐系统等。

GNN 把深度学习应用到图结构 (Graph) 中,其中的图卷积网络 GCN 可以在 Graph 上进行卷积操作,但是 GCN 存在一些缺陷。因此,Bengio 团队在三年前提出了图注意力网络(GAT,Graph Attention Network) ,旨在解决 GCN 存在的问题。


GAT 是空间(卷积)GNN 的代表。由于 CNNs 在计算机视觉领域取得了巨大的成功,研究人员决定将其推广到图形上,因此 GAT 应运而生。


现在,有人用 PyTorch 实现了 GAT 可视化。我们来看看该项目是如何实现的。


微信图片_20211205005434.jpg


可视化


Cora 可视化


说到 GNN,就不得不介绍一下 Cora 数据集。Cora 数据集由许多机器学习论文组成,是近年来图深度学习很喜欢使用的数据集。Cora 中的节点代表研究论文,链接是这些论文之间的引用。项目作者添加了一个用于可视化 Cora 和进行基本网络分析的实用程序。Cora 如下图所示:


微信图片_20211205005510.jpg

节点大小对应于其等级(即进出边的数量)。边的粗细大致对应于边的「popular」或「连接」程度。以下是显示 Cora 上等级(进出边的数量)分布的图:


微信图片_20211205005526.jpg

进和出的等级图是一样的,因为处理的是无向图。在底部的图(等级分布)上,我们可以看到一个有趣的峰值发生在 [2,4] 范围内。这意味着多数节点有少量的边,但是有 1 个节点有 169 条边(绿色大节点)。


注意力可视化


有了一个训练好的 GAT 模型以后,我们就可以将某些节点所学的注意力可视化。节点利用注意力来决定如何聚合周围的节点,如下图所示:


微信图片_20211205005559.jpg


这是 Cora 节点中边数最多的节点之一(引用)。颜色表示同一类的节点。


熵直方图


另一种理解 GAT 没有在 Cora 上学习注意力模式 (即它在学习常量注意力) 的方法是,将节点邻域的注意力权重视为概率分布,计算熵,并在每个节点邻域积累信息。


我们希望 GAT 的注意力分布有偏差。你可以看到橙色的直方图是理想均匀分布的样子,而浅蓝色的是学习后的分布,它们是完全一样的。


微信图片_20211205005626.jpg


分析 Cora 嵌入空间 (t-SNE)


GAT 的输出张量为 shape=(2708,7),其中 2708 是 Cora 中的节点数,7 是类数。用 t-SNE 把这些 7 维向量投影成 2D,得到:



微信图片_20211205005702.jpg


使用方法


方法 1:Jupyter Notebook


只需从 Anaconda 控制台运行 Jupyter Notebook,它将在你的默认浏览器中打开 session。打开 The Annotated GAT.ipynb 即可开始。

注意,如果你得到了 DLL load failed while importing win32api: The specified module could not be found,只需要 pip uninstall pywin32,或者 pip install pywin32、onda install pywin32。


方法 2:使用你选择的 IDE


如果使用自己选择的 IDE,只需要将 Python 环境和设置部分连接起来。


训练 GAT


在 Cora 上训练 GAT 所需的一切都已经设置好了,运行时只需调用 python training_script.py

此外,你还可以:

  • 添加 --should_visualize - 以可视化你的图形数据
  • 在数据的测试部分添加 --should_test - 以评估 GAT
  • 添加 --enable_tensorboard - 开始保存度量标准(准确率、损失)


代码部分的注释很完善,因此你可以了解到训练本身是如何运行的。

该脚本将:

  • 将 checkpoint* .pth 模型转储到 models/checkpoints/
  • 将 final* .pth 模型转储到 models/binaries/
  • 将度量标准保存到中 runs/,只需 tensorboard --logdir=runs 在 Anaconda 中运行即可将其可视化
  • 定期将一些训练元数据写入控制台


通过 tensorboard --logdir=runs 在控制台中调用,并将 http://localhost:6006/URL 粘贴到浏览器中,可以在训练过程中将度量标准可视化:


1638637096(1).png


可视化工具


如果要可视化 t-SNE 嵌入,请注意或嵌入该 visualize_gat_properties 函数的注释,并设置 visualization_type 为:

  • VisualizationType.ATTENTION - 如果希望可视化节点附近的注意力
  • VisualizationType.EMBEDDING - 如果希望可视化嵌入(通过 t-SNE)
  • VisualizationType.ENTROPY - 如果想可视化熵直方图


然后,你就得到了一张优秀的可视化效果图(VisualizationType.ATTENTION 可选):


微信图片_20211205005847.jpg

微信图片_20211205005849.jpg


硬件需求


GAT 不需要那种很强的硬件资源,尤其是如果你只想运行 Cora 的话,有 2GB 以上的 GPU 就可以了。

  • 在 RTX 2080 GPU 上训练 GAT 大约需要 10 秒;
  • 保留 1.5 GB 的 VRAM 内存(PyTorch 的缓存开销,为实际张量分配的内存少得多);
  • 模型本身只有 365 KB。


如果你想了解更多关于 GAT 的内容,请点击下方视频:


1638637167(1).png

点击查看原视频链接


相关文章
|
25天前
|
机器学习/深度学习 PyTorch 算法框架/工具
CNN中的注意力机制综合指南:从理论到Pytorch代码实现
注意力机制已成为深度学习模型的关键组件,尤其在卷积神经网络(CNN)中发挥了重要作用。通过使模型关注输入数据中最相关的部分,注意力机制显著提升了CNN在图像分类、目标检测和语义分割等任务中的表现。本文将详细介绍CNN中的注意力机制,包括其基本概念、不同类型(如通道注意力、空间注意力和混合注意力)以及实际实现方法。此外,还将探讨注意力机制在多个计算机视觉任务中的应用效果及其面临的挑战。无论是图像分类还是医学图像分析,注意力机制都能显著提升模型性能,并在不断发展的深度学习领域中扮演重要角色。
59 10
|
1月前
|
机器学习/深度学习 PyTorch 算法框架/工具
PyTorch 中的动态计算图:实现灵活的神经网络架构
【8月更文第27天】PyTorch 是一款流行的深度学习框架,它以其灵活性和易用性而闻名。与 TensorFlow 等其他框架相比,PyTorch 最大的特点之一是支持动态计算图。这意味着开发者可以在运行时定义网络结构,这为构建复杂的模型提供了极大的便利。本文将深入探讨 PyTorch 中动态计算图的工作原理,并通过一些示例代码展示如何利用这一特性来构建灵活的神经网络架构。
68 1
|
1月前
|
应用服务中间件 nginx Docker
【与时俱进】网络工程师必备技能:Docker基础入门指南,助你轻松应对新时代挑战!
【8月更文挑战第22天】随着容器技术的发展,Docker已成为开发与运维的关键工具。本文简要介绍Docker——一种开源容器化平台,能让应用程序及依赖项被打包成轻量级容器,在任何Linux或Windows机器上运行。文中涵盖Docker的安装步骤、基础命令操作如启动服务、查看版本、拉取与运行容器等。并通过实例演示了如何运行Nginx服务器和基于Dockerfile构建Python Flask应用镜像的过程。这些基础知识将助力网络工程师理解Docker的核心功能,并为实际应用提供指导。
54 2
|
1月前
|
安全 网络安全 网络架构
掌握traceroute:网络工程师解决路由问题的利器
【8月更文挑战第22天】`traceroute`是网络工程师的关键工具,用于追踪数据包从源到目的地的路径,帮助诊断网络问题并优化性能。通过向目标发送具有特定生存时间(TTL)值的数据包,`traceroute`能揭示每跳路由器的信息及延迟,便于识别瓶颈与故障。其基本用法为`traceroute [options] hostname/IP`。
62 1
|
1月前
|
Ubuntu Shell 网络架构
网络工程师的秘密武器:为何他们必须掌握Docker的基础知识?
【8月更文挑战第20天】在IT领域,Docker作为主流容器化平台,简化了应用部署与管理。网络工程师虽不必精通Docker,但需了解其基本概念如镜像、容器等,及如何创建、运行容器,还需掌握Docker网络模式如bridge、overlay等。这有助于与开发团队协作,设计高效网络架构。例如,通过`docker pull ubuntu`和`docker run -it ubuntu /bin/bash`即可拉取并启动Ubuntu容器。了解这些基础知识能促进跨团队沟通,适应快速发展的IT行业需求。
29 0
|
15天前
|
机器学习/深度学习
小土堆-pytorch-神经网络-损失函数与反向传播_笔记
在使用损失函数时,关键在于匹配输入和输出形状。例如,在L1Loss中,输入形状中的N代表批量大小。以下是具体示例:对于相同形状的输入和目标张量,L1Loss默认计算差值并求平均;此外,均方误差(MSE)也是常用损失函数。实战中,损失函数用于计算模型输出与真实标签间的差距,并通过反向传播更新模型参数。
|
27天前
|
图形学 C#
超实用!深度解析Unity引擎,手把手教你从零开始构建精美的2D平面冒险游戏,涵盖资源导入、角色控制与动画、碰撞检测等核心技巧,打造沉浸式游戏体验完全指南
【8月更文挑战第31天】本文是 Unity 2D 游戏开发的全面指南,手把手教你从零开始构建精美的平面冒险游戏。首先,通过 Unity Hub 创建 2D 项目并导入游戏资源。接着,编写 `PlayerController` 脚本来实现角色移动,并添加动画以增强视觉效果。最后,通过 Collider 2D 组件实现碰撞检测等游戏机制。每一步均展示 Unity 在 2D 游戏开发中的强大功能。
69 6
|
26天前
|
缓存 运维 监控
|
1月前
|
监控 安全 网络协议
【网络工程师必备神器】锐捷设备命令大全:一文在手,天下我有!
【8月更文挑战第22天】锐捷网络专攻网络解决方案,其设备广泛应用在教育、政府及企业等领域。本文汇总了锐捷设备常用命令及其应用场景:包括登录与退出设备、查看系统状态、接口与VLAN配置、路由与QoS设定、安全配置及日志监控等。通过示例如telnet/ssh登录、display命令查看信息、配置IP地址与VLAN、设置静态路由与OSPF、限速与队列调度、端口安全与ACL、SNMP监控与重启设备等,助力工程师高效管理与维护网络。
48 4
|
1月前
|
安全 网络安全 网络虚拟化
网络工程师必知的神秘术语大全究竟藏着哪些关键信息?快来一探究竟!
【8月更文挑战第22天】这份最新整理的网络技术中英文术语大全对于网络工程师来说是一份宝贵的资源。它可以帮助网络工程师更好地理解和掌握网络技术,提高工作效率,解决各种网络问题。无论是在网络规划、设计、实施还是维护阶段,这些术语都将发挥重要的作用。让我们一起收藏这份术语大全,为网络技术的学习和实践打下坚实的基础。
35 1