JAX 中文文档(三)(2)

简介: JAX 中文文档(三)

JAX 中文文档(三)(1)https://developer.aliyun.com/article/1559702


对 JAX 程序进行性能分析

原文:jax.readthedocs.io/en/latest/profiling.html

使用 Perfetto 查看程序跟踪

我们可以使用 JAX 分析器生成可以使用Perfetto 可视化工具查看的 JAX 程序的跟踪。目前,此方法会阻塞程序,直到点击链接并加载 Perfetto UI 以打开跟踪为止。如果您希望获取性能分析信息而无需任何交互,请查看下面的 Tensorboard 分析器。

with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
  # Run the operations to be profiled
  key = jax.random.key(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready() 

计算完成后,程序会提示您打开链接到ui.perfetto.dev。打开链接后,Perfetto UI 将加载跟踪文件并打开可视化工具。

[外链图片转存中…(img-jkyqYGK7-1718950659302)]

加载链接后,程序执行将继续。链接在打开一次后将不再有效,但将重定向到一个保持有效的新 URL。然后,您可以在 Perfetto UI 中单击“共享”按钮,创建可与他人共享的跟踪的永久链接。

远程分析

在对远程运行的代码进行性能分析(例如在托管的虚拟机上)时,您需要在端口 9001 上建立 SSH 隧道以使链接工作。您可以使用以下命令执行此操作:

$  ssh  -L  9001:127.0.0.1:9001  <user>@<host> 

或者如果您正在使用 Google Cloud:

$  gcloud  compute  ssh  <machine-name>  --  -L  9001:127.0.0.1:9001 

手动捕获

而不是使用jax.profiler.trace以编程方式捕获跟踪,您可以通过在感兴趣的脚本中调用jax.profiler.start_server()来启动分析服务器。如果您只需在脚本的某部分保持分析服务器活动,则可以通过调用jax.profiler.stop_server()来关闭它。

脚本运行后并且分析服务器已启动后,我们可以通过运行以下命令手动捕获和跟踪:

$  python  -m  jax.collect_profile  <port>  <duration_in_ms> 

默认情况下,生成的跟踪信息会被转储到临时目录中,但可以通过传递--log_dir=<自定义目录>来覆盖此设置。另外,默认情况下,程序将提示您打开链接到ui.perfetto.dev。打开链接后,Perfetto UI 将加载跟踪文件并打开可视化工具。通过传递--no_perfetto_link命令可以禁用此功能。或者,您也可以将 Tensorboard 指向log_dir以分析跟踪(参见下面的“Tensorboard 分析”部分)。

TensorBoard 性能分析

TensorBoard 的分析器可用于分析 JAX 程序。Tensorboard 是获取和可视化程序性能跟踪和分析(包括 GPU 和 TPU 上的活动)的好方法。最终结果看起来类似于这样:


安装

TensorBoard 分析器仅与捆绑有 TensorFlow 的 TensorBoard 版本一起提供。

pip  install  tensorflow  tensorboard-plugin-profile 

如果您已安装了 TensorFlow,则只需安装tensorboard-plugin-profile pip 包。请注意仅安装一个版本的 TensorFlow 或 TensorBoard,否则可能会遇到下面描述的“重复插件”错误。有关安装 TensorBoard 的更多信息,请参见www.tensorflow.org/guide/profiler

程序化捕获

您可以通过jax.profiler.start_trace()jax.profiler.stop_trace()方法来配置您的代码以捕获性能分析器的追踪。调用start_trace()时需要指定写入追踪文件的目录。这个目录应该与启动 TensorBoard 时使用的--logdir目录相同。然后,您可以使用 TensorBoard 来查看这些追踪信息。

例如,要获取性能分析器的追踪:

import jax
jax.profiler.start_trace("/tmp/tensorboard")
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace() 

注意block_until_ready()调用。我们使用这个函数来确保设备上的执行被追踪到。有关为什么需要这样做的详细信息,请参见异步调度部分。

您还可以使用jax.profiler.trace()上下文管理器作为start_tracestop_trace的替代方法:

import jax
with jax.profiler.trace("/tmp/tensorboard"):
  key = jax.random.key(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready() 

要查看追踪信息,请首先启动 TensorBoard(如果尚未启动):

$  tensorboard  --logdir=/tmp/tensorboard
[...]
Serving  TensorBoard  on  localhost;  to  expose  to  the  network,  use  a  proxy  or  pass  --bind_all
TensorBoard  2.5.0  at  http://localhost:6006/  (Press  CTRL+C  to  quit) 

在这个示例中,您应该能够在localhost:6006/加载 TensorBoard。您可以使用--port标志指定不同的端口。如果在远程服务器上运行 JAX,请参见下面的远程机器上的分析。

然后,要么在右上角的下拉菜单中选择“Profile”,要么直接访问localhost:6006/#profile。可用的追踪信息会显示在左侧的“Runs”下拉菜单中。选择您感兴趣的运行,并在“Tools”下选择trace_viewer。现在您应该能看到执行时间轴。您可以使用 WASD 键来导航追踪信息,点击或拖动以选择事件并查看底部的更多详细信息。有关使用追踪查看器的更多详细信息,请参阅这些 TensorFlow 文档

您还可以使用memory_viewerop_profilegraph_viewer工具。

通过 TensorBoard 手动捕获

以下是从运行中的程序中手动触发 N 秒追踪的捕获说明。

  1. 启动 TensorBoard 服务器:
tensorboard  --logdir  /tmp/tensorboard/ 
  1. localhost:6006/处应该能够加载 TensorBoard。您可以使用--port标志指定不同的端口。如果在远程服务器上运行 JAX,请参见下面的远程机器上的分析。
  2. 在您希望进行分析的 Python 程序或进程中,将以下内容添加到开头的某个位置:
import jax.profiler
jax.profiler.start_server(9999) 
  1. 这将启动 TensorBoard 连接到的性能分析器服务器。在继续下一步之前,必须先运行性能分析器服务器。完成后,可以调用jax.profiler.stop_server()来关闭它。
    如果你想要分析一个长时间运行的程序片段(例如长时间的训练循环),你可以将此代码放在程序开头并像往常一样启动程序。如果你想要分析一个短程序(例如微基准测试),一种选择是在 IPython shell 中启动分析器服务器,并在下一步开始捕获后用 %run 运行短程序。另一种选择是在程序开头启动分析器服务器,并使用 time.sleep() 给你足够的时间启动捕获。
  2. 打开localhost:6006/#profile,并点击左上角的“CAPTURE PROFILE”按钮。将“localhost:9999”作为分析服务的 URL(这是你在上一步中启动的分析器服务器的地址)。输入你想要进行分析的毫秒数,然后点击“CAPTURE”。
  3. 如果你想要分析的代码尚未运行(例如在 Python shell 中启动了分析器服务器),请在进行捕获时运行它。
  4. 捕获完成后,TensorBoard 应会自动刷新。(并非所有 TensorBoard 分析功能都与 JAX 连接,所以初始时看起来可能没有捕获到任何内容。)在左侧的“工具”下,选择 trace_viewer
    现在你应该可以看到执行的时间轴。你可以使用 WASD 键来导航跟踪,点击或拖动选择事件以在底部查看更多详细信息。参见这些 TensorFlow 文档获取有关使用跟踪查看器的更多详细信息。
    你也可以使用 memory_viewerop_profilegraph_viewer 工具。

添加自定义跟踪事件

默认情况下,跟踪查看器中的事件大多是低级内部 JAX 函数。你可以使用 jax.profiler.TraceAnnotationjax.profiler.annotate_function() 在你的代码中添加自定义事件和函数。

故障排除

GPU 分析

运行在 GPU 上的程序应该在跟踪查看器顶部附近生成 GPU 流的跟踪。如果只看到主机跟踪,请检查程序日志和/或输出,查看以下错误消息。

如果出现类似 Could not load dynamic library 'libcupti.so.10.1' 的错误

完整错误:

W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcupti.so.10.1'; dlerror: libcupti.so.10.1: cannot open shared object file: No such file or directory
2020-06-12 13:19:59.822799: E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1422] function cupti_interface_->Subscribe( &subscriber_, (CUpti_CallbackFunc)ApiCallback, this)failed with error CUPTI could not be loaded or symbol could not be found. 

libcupti.so的路径添加到环境变量LD_LIBRARY_PATH中。(尝试使用locate libcupti.so来找到路径。)例如:

export  LD_LIBRARY_PATH=/usr/local/cuda-10.1/extras/CUPTI/lib64/:$LD_LIBRARY_PATH 

即使在做了以上步骤后仍然收到 Could not load dynamic library 错误消息,请检查 GPU 跟踪是否仍然显示在跟踪查看器中。有时即使一切正常,它也会出现此消息,因为它在多个位置查找 libcupti 库。

如果出现类似 failed with error CUPTI_ERROR_INSUFFICIENT_PRIVILEGES 的错误

完整错误:

E  external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1445]  function  cupti_interface_->EnableCallback(  0  ,  subscriber_,  CUPTI_CB_DOMAIN_DRIVER_API,  cbid)failed  with  error  CUPTI_ERROR_INSUFFICIENT_PRIVILEGES
2020-06-12  14:31:54.097791:  E  external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1487]  function  cupti_interface_->ActivityDisable(activity)failed  with  error  CUPTI_ERROR_NOT_INITIALIZED 

运行以下命令(注意这将需要重新启动):

echo  'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"'  |  sudo  tee  -a  /etc/modprobe.d/nvidia-kernel-common.conf
sudo  update-initramfs  -u
sudo  reboot  now 

查看更多关于此错误的信息,请参阅NVIDIA 的文档

在远程机器上进行性能分析

如果要分析的 JAX 程序正在远程机器上运行,一种选择是在远程机器上执行上述所有说明(特别是在远程机器上启动 TensorBoard  服务器),然后使用 SSH 本地端口转发从本地访问 TensorBoard Web UI。使用以下 SSH 命令将默认的 TensorBoard  端口 6006 从本地转发到远程机器:

ssh  -L  6006:localhost:6006  <remote  server  address> 

或者如果您正在使用 Google Cloud:

$  gcloud  compute  ssh  <machine-name>  --  -L  6006:localhost:6006 
```#### 多个 TensorBoard 安装
**如果启动 TensorBoard 失败,并出现类似于`ValueError: Duplicate plugins for name projector`的错误**
这通常是因为安装了两个版本的 TensorBoard 和/或 TensorFlow(例如,`tensorflow`、`tf-nightly`、`tensorboard`和`tb-nightly` pip 包都包含 TensorBoard)。卸载一个 pip 包可能会导致`tensorboard`可执行文件被移除,难以替换,因此可能需要卸载所有内容并重新安装单个版本:
```py
pip  uninstall  tensorflow  tf-nightly  tensorboard  tb-nightly
pip  install  tensorflow 

Nsight

NVIDIA 的Nsight工具可用于跟踪和分析 GPU 上的 JAX 代码。有关详情,请参阅Nsight文档

设备内存分析

原文:jax.readthedocs.io/en/latest/device_memory_profiling.html

注意

2023 年 5 月更新:我们建议使用 Tensorboard 进行设备内存分析。在进行分析后,打开 Tensorboard 分析器的 memory_viewer 标签以获取更详细和易于理解的设备内存使用情况。

JAX 设备内存分析器允许我们探索 JAX 程序如何以及为何使用 GPU 或 TPU 内存。例如,它可用于:

  • 查明在特定时间点哪些数组和可执行文件位于 GPU 内存中,或者
  • 追踪内存泄漏。

安装

JAX 设备内存分析器生成的输出可使用 pprof (google/pprof) 解释。首先按照其 安装说明 安装 pprof。撰写时,安装 pprof 需要先安装版本为 1.16+ 的 GoGraphviz,然后运行

go  install  github.com/google/pprof@latest 

安装 pprof 作为 $GOPATH/bin/pprof,其中 GOPATH 默认为 ~/go

注意

来自 google/pprofpprof 版本与作为 gperftools 软件包一部分分发的同名旧工具不同。gperftools 版本的 pprof 不适用于 JAX。

理解 JAX 程序如何使用 GPU 或 TPU 内存

设备内存分析器的常见用途是找出为何 JAX 程序使用大量 GPU 或 TPU 内存,例如调试内存不足问题。

要将设备内存分析保存到磁盘,使用 jax.profiler.save_device_memory_profile()。例如,考虑以下 Python 程序:

import jax
import jax.numpy as jnp
import jax.profiler
def func1(x):
  return jnp.tile(x, 10) * 0.5
def func2(x):
  y = func1(x)
  return y, jnp.tile(x, 10) + 1
x = jax.random.normal(jax.random.key(42), (1000, 1000))
y, z = func2(x)
z.block_until_ready()
jax.profiler.save_device_memory_profile("memory.prof") 

如果我们首先运行上述程序,然后执行

pprof  --web  memory.prof 

pprof 打开一个包含设备内存分析调用图格式的 Web 浏览器:

[外链图片转存中…(img-NBNUIEpy-1718950659302)]

调用图是在每个活动缓冲区分配的 Python 栈的可视化。例如,在这个特定情况下,可视化显示 func2 及其被调用者负责分配了 76.30MB,其中 38.15MB 是在从 func1func2 的调用中分配的。有关如何解释调用图可视化的更多信息,请参阅 pprof 文档

使用 jax.jit() 编译的函数对设备内存分析器不透明。也就是说,任何在 jit 编译函数内部分配的内存都将归因于整个函数。

在本例中,调用 block_until_ready() 是为了确保在收集设备内存分析之前 func2 完成。有关更多详细信息,请参阅异步调度。

调试内存泄漏

我们还可以使用 JAX 设备内存分析器,通过使用 pprof 来可视化在不同时间点获取的两个设备内存配置文件中的内存使用情况变化,以追踪内存泄漏。例如,考虑以下程序,该程序将 JAX 数组累积到一个不断增长的 Python 列表中。

import jax
import jax.numpy as jnp
import jax.profiler
def afunction():
  return jax.random.normal(jax.random.key(77), (1000000,))
z = afunction()
def anotherfunc():
  arrays = []
  for i in range(1, 10):
    x = jax.random.normal(jax.random.key(42), (i, 10000))
    arrays.append(x)
    x.block_until_ready()
    jax.profiler.save_device_memory_profile(f"memory{i}.prof")
anotherfunc() 

如果我们仅在执行结束时可视化设备内存配置文件(memory9.prof),则可能不明显,即 anotherfunc 中的每次循环迭代都会累积更多的设备内存分配:

pprof  --web  memory9.prof 

[外链图片转存中…(img-UJpftsmN-1718950659303)]

afunction 内部的大型但固定分配主导配置文件,但不会随时间增长。

通过使用 pprof--diff_base 功能 来可视化循环迭代中内存使用情况的变化,我们可以找出程序内存使用量随时间增加的原因:

pprof  --web  --diff_base  memory1.prof  memory9.prof 


可视化显示,内存增长可以归因于 anotherfunc 中对 normal 的调用。

在 JAX 中进行运行时值调试

原文:jax.readthedocs.io/en/latest/debugging/index.html

是否遇到梯度爆炸?NaN 使你牙齿咬紧?只想查看计算中间值?请查看以下 JAX 调试工具!本页提供了 TL;DR 摘要,并且您可以点击底部的“阅读更多”链接了解更多信息。

目录:

  • 使用 jax.debug 进行交互式检查
  • 使用 jax.experimental.checkify 进行功能错误检查
  • 使用 JAX 的调试标志抛出 Python 错误

使用 jax.debug 进行交互式检查

TL;DR 使用 jax.debug.print()jax.jitjax.pmappjit 装饰的函数中将值打印到 stdout,并使用 jax.debug.breakpoint() 暂停执行编译函数以检查调用堆栈中的值:

import jax
import jax.numpy as jnp
@jax.jit
def f(x):
  jax.debug.print("🤯 {x} 🤯", x=x)
  y = jnp.sin(x)
  jax.debug.breakpoint()
  jax.debug.print("🤯 {y} 🤯", y=y)
  return y
f(2.)
# Prints:
# 🤯 2.0 🤯
# Enters breakpoint to inspect values!
# 🤯 0.9092974662780762 🤯 

点击此处了解更多!

使用 jax.experimental.checkify 进行功能错误检查

TL;DR Checkify 允许您向 JAX 代码添加 jit 可用的运行时错误检查(例如越界索引)。使用 checkify.checkify 转换以及类似断言的 checkify.check 函数,向 JAX 代码添加运行时检查:

from jax.experimental import checkify
import jax
import jax.numpy as jnp
def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  y = x[i]
  z = jnp.sin(y)
  return z
jittable_f = checkify.checkify(f)
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f)) 

您还可以使用 checkify 自动添加常见检查:

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f) 

点击此处了解更多!

使用 JAX 的调试标志抛出 Python 错误

TL;DR 启用 jax_debug_nans 标志,自动检测在 jax.jit 编译的代码中生成 NaN 时(但不在 jax.pmapjax.pjit 编译的代码中),并启用 jax_disable_jit 标志以禁用 JIT 编译,从而使用传统的 Python 调试工具如 printpdb

import jax
jax.config.update("jax_debug_nans", True)
def f(x, y):
  return x / y
jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception! 

点击此处了解更多!

阅读更多

  • jax.debug.printjax.debug.breakpoint
  • checkify 转换
  • JAX 调试标志


JAX 中文文档(三)(3)https://developer.aliyun.com/article/1559705

相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
4月前
|
并行计算 API 异构计算
JAX 中文文档(六)(2)
JAX 中文文档(六)
40 1
|
4月前
|
机器学习/深度学习 存储 移动开发
JAX 中文文档(八)(1)
JAX 中文文档(八)
29 1
|
4月前
|
机器学习/深度学习 索引 Python
JAX 中文文档(四)(1)
JAX 中文文档(四)
42 0
|
4月前
|
C++ 索引 Python
JAX 中文文档(九)(2)
JAX 中文文档(九)
24 0
|
4月前
|
存储 移动开发 Python
JAX 中文文档(八)(2)
JAX 中文文档(八)
28 0
|
4月前
|
测试技术 TensorFlow 算法框架/工具
JAX 中文文档(五)(2)
JAX 中文文档(五)
56 0
|
4月前
|
存储 机器学习/深度学习 编译器
JAX 中文文档(九)(1)
JAX 中文文档(九)
50 0
|
4月前
|
机器学习/深度学习 API 索引
JAX 中文文档(二)(2)
JAX 中文文档(二)
34 0
|
4月前
|
机器学习/深度学习 算法 异构计算
JAX 中文文档(七)(2)
JAX 中文文档(七)
29 0
|
4月前
|
存储 编译器 芯片
JAX 中文文档(五)(5)
JAX 中文文档(五)
41 0