JAX 中文文档(三)(1)https://developer.aliyun.com/article/1559702
对 JAX 程序进行性能分析
使用 Perfetto 查看程序跟踪
我们可以使用 JAX 分析器生成可以使用Perfetto 可视化工具查看的 JAX 程序的跟踪。目前,此方法会阻塞程序,直到点击链接并加载 Perfetto UI 以打开跟踪为止。如果您希望获取性能分析信息而无需任何交互,请查看下面的 Tensorboard 分析器。
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True): # Run the operations to be profiled key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready()
计算完成后,程序会提示您打开链接到ui.perfetto.dev
。打开链接后,Perfetto UI 将加载跟踪文件并打开可视化工具。
[外链图片转存中…(img-jkyqYGK7-1718950659302)]
加载链接后,程序执行将继续。链接在打开一次后将不再有效,但将重定向到一个保持有效的新 URL。然后,您可以在 Perfetto UI 中单击“共享”按钮,创建可与他人共享的跟踪的永久链接。
远程分析
在对远程运行的代码进行性能分析(例如在托管的虚拟机上)时,您需要在端口 9001 上建立 SSH 隧道以使链接工作。您可以使用以下命令执行此操作:
$ ssh -L 9001:127.0.0.1:9001 <user>@<host>
或者如果您正在使用 Google Cloud:
$ gcloud compute ssh <machine-name> -- -L 9001:127.0.0.1:9001
手动捕获
而不是使用jax.profiler.trace
以编程方式捕获跟踪,您可以通过在感兴趣的脚本中调用jax.profiler.start_server()
来启动分析服务器。如果您只需在脚本的某部分保持分析服务器活动,则可以通过调用jax.profiler.stop_server()
来关闭它。
脚本运行后并且分析服务器已启动后,我们可以通过运行以下命令手动捕获和跟踪:
$ python -m jax.collect_profile <port> <duration_in_ms>
默认情况下,生成的跟踪信息会被转储到临时目录中,但可以通过传递--log_dir=<自定义目录>
来覆盖此设置。另外,默认情况下,程序将提示您打开链接到ui.perfetto.dev
。打开链接后,Perfetto UI 将加载跟踪文件并打开可视化工具。通过传递--no_perfetto_link
命令可以禁用此功能。或者,您也可以将 Tensorboard 指向log_dir
以分析跟踪(参见下面的“Tensorboard 分析”部分)。
TensorBoard 性能分析
TensorBoard 的分析器可用于分析 JAX 程序。Tensorboard 是获取和可视化程序性能跟踪和分析(包括 GPU 和 TPU 上的活动)的好方法。最终结果看起来类似于这样:
安装
TensorBoard 分析器仅与捆绑有 TensorFlow 的 TensorBoard 版本一起提供。
pip install tensorflow tensorboard-plugin-profile
如果您已安装了 TensorFlow,则只需安装tensorboard-plugin-profile
pip 包。请注意仅安装一个版本的 TensorFlow 或 TensorBoard,否则可能会遇到下面描述的“重复插件”错误。有关安装 TensorBoard 的更多信息,请参见www.tensorflow.org/guide/profiler
。
程序化捕获
您可以通过jax.profiler.start_trace()
和jax.profiler.stop_trace()
方法来配置您的代码以捕获性能分析器的追踪。调用start_trace()
时需要指定写入追踪文件的目录。这个目录应该与启动 TensorBoard 时使用的--logdir
目录相同。然后,您可以使用 TensorBoard 来查看这些追踪信息。
例如,要获取性能分析器的追踪:
import jax jax.profiler.start_trace("/tmp/tensorboard") # Run the operations to be profiled key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready() jax.profiler.stop_trace()
注意block_until_ready()
调用。我们使用这个函数来确保设备上的执行被追踪到。有关为什么需要这样做的详细信息,请参见异步调度部分。
您还可以使用jax.profiler.trace()
上下文管理器作为start_trace
和stop_trace
的替代方法:
import jax with jax.profiler.trace("/tmp/tensorboard"): key = jax.random.key(0) x = jax.random.normal(key, (5000, 5000)) y = x @ x y.block_until_ready()
要查看追踪信息,请首先启动 TensorBoard(如果尚未启动):
$ tensorboard --logdir=/tmp/tensorboard [...] Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all TensorBoard 2.5.0 at http://localhost:6006/ (Press CTRL+C to quit)
在这个示例中,您应该能够在localhost:6006/
加载 TensorBoard。您可以使用--port
标志指定不同的端口。如果在远程服务器上运行 JAX,请参见下面的远程机器上的分析。
然后,要么在右上角的下拉菜单中选择“Profile”,要么直接访问localhost:6006/#profile
。可用的追踪信息会显示在左侧的“Runs”下拉菜单中。选择您感兴趣的运行,并在“Tools”下选择trace_viewer
。现在您应该能看到执行时间轴。您可以使用 WASD 键来导航追踪信息,点击或拖动以选择事件并查看底部的更多详细信息。有关使用追踪查看器的更多详细信息,请参阅这些 TensorFlow 文档。
您还可以使用memory_viewer
、op_profile
和graph_viewer
工具。
通过 TensorBoard 手动捕获
以下是从运行中的程序中手动触发 N 秒追踪的捕获说明。
- 启动 TensorBoard 服务器:
tensorboard --logdir /tmp/tensorboard/
- 在
localhost:6006/
处应该能够加载 TensorBoard。您可以使用--port
标志指定不同的端口。如果在远程服务器上运行 JAX,请参见下面的远程机器上的分析。 - 在您希望进行分析的 Python 程序或进程中,将以下内容添加到开头的某个位置:
import jax.profiler jax.profiler.start_server(9999)
- 这将启动 TensorBoard 连接到的性能分析器服务器。在继续下一步之前,必须先运行性能分析器服务器。完成后,可以调用
jax.profiler.stop_server()
来关闭它。
如果你想要分析一个长时间运行的程序片段(例如长时间的训练循环),你可以将此代码放在程序开头并像往常一样启动程序。如果你想要分析一个短程序(例如微基准测试),一种选择是在 IPython shell 中启动分析器服务器,并在下一步开始捕获后用%run
运行短程序。另一种选择是在程序开头启动分析器服务器,并使用time.sleep()
给你足够的时间启动捕获。 - 打开
localhost:6006/#profile
,并点击左上角的“CAPTURE PROFILE”按钮。将“localhost:9999”作为分析服务的 URL(这是你在上一步中启动的分析器服务器的地址)。输入你想要进行分析的毫秒数,然后点击“CAPTURE”。 - 如果你想要分析的代码尚未运行(例如在 Python shell 中启动了分析器服务器),请在进行捕获时运行它。
- 捕获完成后,TensorBoard 应会自动刷新。(并非所有 TensorBoard 分析功能都与 JAX 连接,所以初始时看起来可能没有捕获到任何内容。)在左侧的“工具”下,选择
trace_viewer
。
现在你应该可以看到执行的时间轴。你可以使用 WASD 键来导航跟踪,点击或拖动选择事件以在底部查看更多详细信息。参见这些 TensorFlow 文档获取有关使用跟踪查看器的更多详细信息。
你也可以使用memory_viewer
、op_profile
和graph_viewer
工具。
添加自定义跟踪事件
默认情况下,跟踪查看器中的事件大多是低级内部 JAX 函数。你可以使用 jax.profiler.TraceAnnotation
和 jax.profiler.annotate_function()
在你的代码中添加自定义事件和函数。
故障排除
GPU 分析
运行在 GPU 上的程序应该在跟踪查看器顶部附近生成 GPU 流的跟踪。如果只看到主机跟踪,请检查程序日志和/或输出,查看以下错误消息。
如果出现类似 Could not load dynamic library 'libcupti.so.10.1'
的错误
完整错误:
W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcupti.so.10.1'; dlerror: libcupti.so.10.1: cannot open shared object file: No such file or directory 2020-06-12 13:19:59.822799: E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1422] function cupti_interface_->Subscribe( &subscriber_, (CUpti_CallbackFunc)ApiCallback, this)failed with error CUPTI could not be loaded or symbol could not be found.
将libcupti.so
的路径添加到环境变量LD_LIBRARY_PATH
中。(尝试使用locate libcupti.so
来找到路径。)例如:
export LD_LIBRARY_PATH=/usr/local/cuda-10.1/extras/CUPTI/lib64/:$LD_LIBRARY_PATH
即使在做了以上步骤后仍然收到 Could not load dynamic library
错误消息,请检查 GPU 跟踪是否仍然显示在跟踪查看器中。有时即使一切正常,它也会出现此消息,因为它在多个位置查找 libcupti
库。
如果出现类似 failed with error CUPTI_ERROR_INSUFFICIENT_PRIVILEGES
的错误
完整错误:
E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1445] function cupti_interface_->EnableCallback( 0 , subscriber_, CUPTI_CB_DOMAIN_DRIVER_API, cbid)failed with error CUPTI_ERROR_INSUFFICIENT_PRIVILEGES 2020-06-12 14:31:54.097791: E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1487] function cupti_interface_->ActivityDisable(activity)failed with error CUPTI_ERROR_NOT_INITIALIZED
运行以下命令(注意这将需要重新启动):
echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"' | sudo tee -a /etc/modprobe.d/nvidia-kernel-common.conf sudo update-initramfs -u sudo reboot now
查看更多关于此错误的信息,请参阅NVIDIA 的文档。
在远程机器上进行性能分析
如果要分析的 JAX 程序正在远程机器上运行,一种选择是在远程机器上执行上述所有说明(特别是在远程机器上启动 TensorBoard 服务器),然后使用 SSH 本地端口转发从本地访问 TensorBoard Web UI。使用以下 SSH 命令将默认的 TensorBoard 端口 6006 从本地转发到远程机器:
ssh -L 6006:localhost:6006 <remote server address>
或者如果您正在使用 Google Cloud:
$ gcloud compute ssh <machine-name> -- -L 6006:localhost:6006 ```#### 多个 TensorBoard 安装 **如果启动 TensorBoard 失败,并出现类似于`ValueError: Duplicate plugins for name projector`的错误** 这通常是因为安装了两个版本的 TensorBoard 和/或 TensorFlow(例如,`tensorflow`、`tf-nightly`、`tensorboard`和`tb-nightly` pip 包都包含 TensorBoard)。卸载一个 pip 包可能会导致`tensorboard`可执行文件被移除,难以替换,因此可能需要卸载所有内容并重新安装单个版本: ```py pip uninstall tensorflow tf-nightly tensorboard tb-nightly pip install tensorflow
Nsight
NVIDIA 的Nsight
工具可用于跟踪和分析 GPU 上的 JAX 代码。有关详情,请参阅Nsight
文档。
设备内存分析
原文:
jax.readthedocs.io/en/latest/device_memory_profiling.html
注意
2023 年 5 月更新:我们建议使用 Tensorboard 进行设备内存分析。在进行分析后,打开 Tensorboard 分析器的 memory_viewer
标签以获取更详细和易于理解的设备内存使用情况。
JAX 设备内存分析器允许我们探索 JAX 程序如何以及为何使用 GPU 或 TPU 内存。例如,它可用于:
- 查明在特定时间点哪些数组和可执行文件位于 GPU 内存中,或者
- 追踪内存泄漏。
安装
JAX 设备内存分析器生成的输出可使用 pprof (google/pprof) 解释。首先按照其 安装说明 安装 pprof
。撰写时,安装 pprof
需要先安装版本为 1.16+ 的 Go,Graphviz,然后运行
go install github.com/google/pprof@latest
安装 pprof
作为 $GOPATH/bin/pprof
,其中 GOPATH
默认为 ~/go
。
注意
来自 google/pprof 的 pprof
版本与作为 gperftools
软件包一部分分发的同名旧工具不同。gperftools
版本的 pprof
不适用于 JAX。
理解 JAX 程序如何使用 GPU 或 TPU 内存
设备内存分析器的常见用途是找出为何 JAX 程序使用大量 GPU 或 TPU 内存,例如调试内存不足问题。
要将设备内存分析保存到磁盘,使用 jax.profiler.save_device_memory_profile()
。例如,考虑以下 Python 程序:
import jax import jax.numpy as jnp import jax.profiler def func1(x): return jnp.tile(x, 10) * 0.5 def func2(x): y = func1(x) return y, jnp.tile(x, 10) + 1 x = jax.random.normal(jax.random.key(42), (1000, 1000)) y, z = func2(x) z.block_until_ready() jax.profiler.save_device_memory_profile("memory.prof")
如果我们首先运行上述程序,然后执行
pprof --web memory.prof
pprof
打开一个包含设备内存分析调用图格式的 Web 浏览器:
[外链图片转存中…(img-NBNUIEpy-1718950659302)]
调用图是在每个活动缓冲区分配的 Python 栈的可视化。例如,在这个特定情况下,可视化显示 func2
及其被调用者负责分配了 76.30MB,其中 38.15MB 是在从 func1
到 func2
的调用中分配的。有关如何解释调用图可视化的更多信息,请参阅 pprof 文档。
使用 jax.jit()
编译的函数对设备内存分析器不透明。也就是说,任何在 jit
编译函数内部分配的内存都将归因于整个函数。
在本例中,调用 block_until_ready()
是为了确保在收集设备内存分析之前 func2
完成。有关更多详细信息,请参阅异步调度。
调试内存泄漏
我们还可以使用 JAX 设备内存分析器,通过使用 pprof
来可视化在不同时间点获取的两个设备内存配置文件中的内存使用情况变化,以追踪内存泄漏。例如,考虑以下程序,该程序将 JAX 数组累积到一个不断增长的 Python 列表中。
import jax import jax.numpy as jnp import jax.profiler def afunction(): return jax.random.normal(jax.random.key(77), (1000000,)) z = afunction() def anotherfunc(): arrays = [] for i in range(1, 10): x = jax.random.normal(jax.random.key(42), (i, 10000)) arrays.append(x) x.block_until_ready() jax.profiler.save_device_memory_profile(f"memory{i}.prof") anotherfunc()
如果我们仅在执行结束时可视化设备内存配置文件(memory9.prof
),则可能不明显,即 anotherfunc
中的每次循环迭代都会累积更多的设备内存分配:
pprof --web memory9.prof
[外链图片转存中…(img-UJpftsmN-1718950659303)]
在 afunction
内部的大型但固定分配主导配置文件,但不会随时间增长。
通过使用 pprof
的 --diff_base
功能 来可视化循环迭代中内存使用情况的变化,我们可以找出程序内存使用量随时间增加的原因:
pprof --web --diff_base memory1.prof memory9.prof
可视化显示,内存增长可以归因于 anotherfunc
中对 normal
的调用。
在 JAX 中进行运行时值调试
是否遇到梯度爆炸?NaN 使你牙齿咬紧?只想查看计算中间值?请查看以下 JAX 调试工具!本页提供了 TL;DR 摘要,并且您可以点击底部的“阅读更多”链接了解更多信息。
目录:
- 使用
jax.debug
进行交互式检查 - 使用 jax.experimental.checkify 进行功能错误检查
- 使用 JAX 的调试标志抛出 Python 错误
使用 jax.debug
进行交互式检查
TL;DR 使用 jax.debug.print()
在 jax.jit
、jax.pmap
和 pjit
装饰的函数中将值打印到 stdout,并使用 jax.debug.breakpoint()
暂停执行编译函数以检查调用堆栈中的值:
import jax import jax.numpy as jnp @jax.jit def f(x): jax.debug.print("🤯 {x} 🤯", x=x) y = jnp.sin(x) jax.debug.breakpoint() jax.debug.print("🤯 {y} 🤯", y=y) return y f(2.) # Prints: # 🤯 2.0 🤯 # Enters breakpoint to inspect values! # 🤯 0.9092974662780762 🤯
点击此处了解更多!
使用 jax.experimental.checkify
进行功能错误检查
TL;DR Checkify 允许您向 JAX 代码添加 jit
可用的运行时错误检查(例如越界索引)。使用 checkify.checkify
转换以及类似断言的 checkify.check
函数,向 JAX 代码添加运行时检查:
from jax.experimental import checkify import jax import jax.numpy as jnp def f(x, i): checkify.check(i >= 0, "index needs to be non-negative!") y = x[i] z = jnp.sin(y) return z jittable_f = checkify.checkify(f) err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1) print(err.get()) # >> index needs to be non-negative! (check failed at <...>:6 (f))
您还可以使用 checkify 自动添加常见检查:
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks checked_f = checkify.checkify(f, errors=errors) err, z = checked_f(jnp.ones((5,)), 100) err.throw() # ValueError: out-of-bounds indexing at <..>:7 (f) err, z = checked_f(jnp.ones((5,)), -1) err.throw() # ValueError: index needs to be non-negative! (check failed at <…>:6 (f)) err, z = checked_f(jnp.array([jnp.inf, 1]), 0) err.throw() # ValueError: nan generated by primitive sin at <...>:8 (f)
点击此处了解更多!
使用 JAX 的调试标志抛出 Python 错误
TL;DR 启用 jax_debug_nans
标志,自动检测在 jax.jit
编译的代码中生成 NaN 时(但不在 jax.pmap
或 jax.pjit
编译的代码中),并启用 jax_disable_jit
标志以禁用 JIT 编译,从而使用传统的 Python 调试工具如 print
和 pdb
。
import jax jax.config.update("jax_debug_nans", True) def f(x, y): return x / y jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
点击此处了解更多!
阅读更多
jax.debug.print
和jax.debug.breakpoint
checkify
转换- JAX 调试标志
JAX 中文文档(三)(3)https://developer.aliyun.com/article/1559705