JAX 中文文档(三)(5)

简介: JAX 中文文档(三)

JAX 中文文档(三)(4)https://developer.aliyun.com/article/1559705


JAX 调试标志

原文:jax.readthedocs.io/en/latest/debugging/flags.html

JAX 提供了标志和上下文管理器,可更轻松地捕获错误。

jax_debug_nans 配置选项和上下文管理器

简而言之 启用 jax_debug_nans 标志可自动检测在 jax.jit 编译的代码中产生 NaN(但不适用于 jax.pmapjax.pjit 编译的代码)。

jax_debug_nans 是一个 JAX 标志,当启用时,会在检测到 NaN 时自动引发错误。它对 JIT 编译有特殊处理——如果从 JIT 编译函数检测到 NaN 输出,函数会急切地重新运行(即不经过编译),并在产生 NaN 的具体原始基元处引发错误。

用法

如果您想追踪函数或梯度中出现 NaN 的位置,可以通过以下方式打开 NaN 检查器:

  • 设置 JAX_DEBUG_NANS=True 环境变量;
  • 在主文件顶部附近添加 jax.config.update("jax_debug_nans", True)
  • 在主文件添加 jax.config.parse_flags_with_absl(),然后像 --jax_debug_nans=True 这样使用命令行标志设置选项;

示例

import jax
jax.config.update("jax_debug_nans", True)
def f(x, y):
  return x / y
jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception! 
jax_debug_nans 的优势和限制
优势
  • 易于应用
  • 精确检测产生 NaN 的位置
  • 抛出标准的 Python 异常,与 PDB 事后调试兼容
限制
  • jax.pmapjax.pjit 不兼容
  • 急切重新运行函数可能会很慢
  • 误报阳性(例如故意创建 NaN)

jax_disable_jit 配置选项和上下文管理器

简而言之 启用 jax_disable_jit 标志可禁用 JIT 编译,从而启用传统的 Python 调试工具如 printpdb

jax_disable_jit 是一个 JAX 标志,当启用时,会在整个 JAX 中禁用 JIT 编译(包括在控制流函数如 jax.lax.condjax.lax.scan 中)。

用法

您可以通过以下方式禁用 JIT 编译:

  • 设置 JAX_DISABLE_JIT=True 环境变量;
  • 在主文件顶部附近添加 jax.config.update("jax_disable_jit", True)
  • 在主文件添加 jax.config.parse_flags_with_absl(),然后像 --jax_disable_jit=True 这样使用命令行标志设置选项;

示例

import jax
jax.config.update("jax_disable_jit", True)
def f(x):
  y = jnp.log(x)
  if jnp.isnan(y):
    breakpoint()
  return y
jax.jit(f)(-2.)  # ==> Enters PDB breakpoint! 
jax_disable_jit 的优势和限制
优势
  • 易于应用
  • 启用 Python 内置的 breakpointprint
  • 抛出标准的 Python 异常,与 PDB 事后调试兼容
限制
  • jax.pmapjax.pjit 不兼容
  • 在没有 JIT 编译的情况下运行函数可能会很慢

GPU 性能提示

原文:jax.readthedocs.io/en/latest/gpu_performance_tips.html

本文档专注于神经网络工作负载的性能提示。

矩阵乘法精度

在像 Nvidia A100 一代或更高的最新 GPU 代中,将大多数计算以 bfloat16 精度执行可能是个好主意。例如,如果使用 Flax,可以使用 flax.linen.Dense(..., dtype=jax.numpy.bfloat16) 实例化 Dense 层。以下是一些代码示例:

XLA 性能标志

注意

JAX-Toolbox 还有一个关于 NVIDIA XLA 性能 FLAGS 的页面。

XLA 标志的存在和确切行为可能取决于 jaxlib 版本。

截至 jaxlib==0.4.18(发布于 2023 年 10 月 6 日),设置这些 XLA 标志可以提高性能。其中一些与多 GPU 之间的通信相关,因此仅在多设备运行计算时才相关,而其他一些与每个设备上的代码生成相关。

未来版本中可能会默认设置其中一些。

这些标志可以通过 XLA_FLAGS shell 环境变量进行设置。例如,我们可以将其添加到 Python 文件的顶部:

import os
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_async_collectives=true '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
) 

更多示例,请参阅 XLA Flags recommended for Pax training on Nvidia GPUs

代码生成标志

  • –xla_gpu_enable_triton_softmax_fusion 此标志启用基于 Triton 代码生成支持的模式匹配自动 softmax 融合。默认值为 False。
  • –xla_gpu_triton_gemm_any 使用基于 Triton 的 GEMM(矩阵乘法)发射器支持的任何 GEMM。默认值为 False。

通信标志

  • –xla_gpu_enable_async_collectives 此标志启用诸如AllReduceAllGatherReduceScatterCollectivePermute等集体操作以异步方式进行。异步通信可以将跨核心通信与计算重叠。默认值为 False。
  • –xla_gpu_enable_latency_hiding_scheduler 这个标志启用了延迟隐藏调度器,可以高效地将异步通信与计算重叠。默认值为 False。
  • –xla_gpu_enable_pipelined_collectives 在使用管道并行时,此标志允许将(i+1)层权重的AllGather与第 i 层的计算重叠。它还允许将(i+1)层权重的Reduce/ReduceScatter与第 i 层的计算重叠。默认值为 False。在启用此标志时存在一些错误。
  • –xla_gpu_collective_permute_decomposer_threshold 当执行GSPMD pipelining时,这个标志非常有用。设置一个非零的阈值会将CollectivePermute分解为CollectivePermuteReceiveDoneCollectivePermuteSendDone对,从而可以在每个对应的ReceiveDone/SendDone对之间执行计算,从而实现更多的重叠。默认阈值为 0,不进行分解。将其设置为大于 0 的阈值,例如--xla_gpu_collective_permute_decomposer_threshold=1024,可以启用此功能。
  • –xla_gpu_all_gather_combine_threshold_bytes –xla_gpu_reduce_scatter_combine_threshold_bytes –xla_gpu_all_reduce_combine_threshold_bytes 这些标志用于调整何时将多个小的AllGather/ReduceScatter/AllReduce组合成一个大的AllGather/ReduceScatter/AllReduce,以减少跨设备通信所花费的时间。例如,在基于 Transformer 的工作负载上,可以考虑将AllGather/ReduceScatter阈值调高,以至少组合一个 Transformer 层的权重AllGather/ReduceScatter。默认情况下,combine_threshold_bytes设置为 256。

NCCL 标志

这些 Nvidia NCCL 标志值可能对在 Nvidia GPU 上进行单主机多设备计算有用:

os.environ.update({
  "NCCL_LL128_BUFFSIZE": "-2",
  "NCCL_LL_BUFFSIZE": "-2",
   "NCCL_PROTO": "SIMPLE,LL,LL128",
 }) 

这些 NCCL 标志可以提高单主机通信速度。然而,这些标志对多主机通信似乎不太有用。

多进程

我们建议每个 GPU 使用一个进程,而不是每个节点使用一个进程。在某些情况下,这可以加速 jitted 计算。当在 SLURM 下运行时,jax.distributed.initialize() API 将自动理解此配置。然而,这只是一个经验法则,可能有必要在您的用例中测试每个 GPU 一个进程和每个节点一个进程的情况。

持久编译缓存

原文:jax.readthedocs.io/en/latest/persistent_compilation_cache.html

JAX 具有可选的磁盘缓存用于编译程序。如果启用,JAX 将在磁盘上存储编译程序的副本,这在重复运行相同或类似任务时可以节省重新编译时间。

使用

当设置了cache-location时,编译缓存将启用。这应在第一次编译之前完成。设置位置如下:

import jax
# Make sure this is called before jax runs any operations!
jax.config.update("jax_compilation_cache_dir", "cache-location") 

有关cache-location的更多详细信息,请参见以下各节。

set_cache_dir()是设置cache-location的另一种方法。

本地文件系统

cache-location可以是本地文件系统上的目录。例如:

import jax
jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache") 

注意:缓存没有实现驱逐机制。如果cache-location是本地文件系统中的目录,则其大小将继续增长,除非手动删除文件。

Google Cloud

在 Google Cloud 上运行时,可以将编译缓存放置在 Google Cloud Storage(GCS)存储桶中。我们建议采用以下配置:

  • 在与工作负载运行地区相同的地方创建存储桶。
  • 在与工作负载的 VM 相同的项目中创建存储桶。确保设置了权限,使 VM 能够向存储桶写入。
  • 对于较小的工作负载,不需要复制。较大的工作负载可能会受益于复制。
  • 对于存储桶的默认存储类别,请使用“标准”。
  • 将软删除策略设置为最短期限:7 天。
  • 将对象生命周期设置为预期的工作负载运行时间。例如,如果工作负载预计运行 10 天,则将对象生命周期设置为 10 天。这应该涵盖整个运行期间发生的重启。使用age作为生命周期条件,使用Delete作为操作。详情请参见对象生命周期管理。如果未设置对象生命周期,则缓存将继续增长,因为没有实现驱逐机制。
  • 所有加密策略都受支持。

假设gs://jax-cache是 GCS 存储桶,请设置如下cache-location

import jax
jax.config.update("jax_compilation_cache_dir", "gs://jax-cache") 

tter。默认情况下,combine_threshold_bytes`设置为 256。

NCCL 标志

这些 Nvidia NCCL 标志值可能对在 Nvidia GPU 上进行单主机多设备计算有用:

os.environ.update({
  "NCCL_LL128_BUFFSIZE": "-2",
  "NCCL_LL_BUFFSIZE": "-2",
   "NCCL_PROTO": "SIMPLE,LL,LL128",
 }) 

这些 NCCL 标志可以提高单主机通信速度。然而,这些标志对多主机通信似乎不太有用。

多进程

我们建议每个 GPU 使用一个进程,而不是每个节点使用一个进程。在某些情况下,这可以加速 jitted 计算。当在 SLURM 下运行时,jax.distributed.initialize() API 将自动理解此配置。然而,这只是一个经验法则,可能有必要在您的用例中测试每个 GPU 一个进程和每个节点一个进程的情况。

持久编译缓存

原文:jax.readthedocs.io/en/latest/persistent_compilation_cache.html

JAX 具有可选的磁盘缓存用于编译程序。如果启用,JAX 将在磁盘上存储编译程序的副本,这在重复运行相同或类似任务时可以节省重新编译时间。

使用

当设置了cache-location时,编译缓存将启用。这应在第一次编译之前完成。设置位置如下:

import jax
# Make sure this is called before jax runs any operations!
jax.config.update("jax_compilation_cache_dir", "cache-location") 

有关cache-location的更多详细信息,请参见以下各节。

set_cache_dir()是设置cache-location的另一种方法。

本地文件系统

cache-location可以是本地文件系统上的目录。例如:

import jax
jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache") 

注意:缓存没有实现驱逐机制。如果cache-location是本地文件系统中的目录,则其大小将继续增长,除非手动删除文件。

Google Cloud

在 Google Cloud 上运行时,可以将编译缓存放置在 Google Cloud Storage(GCS)存储桶中。我们建议采用以下配置:

  • 在与工作负载运行地区相同的地方创建存储桶。
  • 在与工作负载的 VM 相同的项目中创建存储桶。确保设置了权限,使 VM 能够向存储桶写入。
  • 对于较小的工作负载,不需要复制。较大的工作负载可能会受益于复制。
  • 对于存储桶的默认存储类别,请使用“标准”。
  • 将软删除策略设置为最短期限:7 天。
  • 将对象生命周期设置为预期的工作负载运行时间。例如,如果工作负载预计运行 10 天,则将对象生命周期设置为 10 天。这应该涵盖整个运行期间发生的重启。使用age作为生命周期条件,使用Delete作为操作。详情请参见对象生命周期管理。如果未设置对象生命周期,则缓存将继续增长,因为没有实现驱逐机制。
  • 所有加密策略都受支持。

假设gs://jax-cache是 GCS 存储桶,请设置如下cache-location

import jax
jax.config.update("jax_compilation_cache_dir", "gs://jax-cache") 
相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
4月前
|
并行计算 API C++
JAX 中文文档(九)(4)
JAX 中文文档(九)
39 1
|
4月前
|
并行计算 API 异构计算
JAX 中文文档(六)(2)
JAX 中文文档(六)
40 1
|
4月前
|
存储 PyTorch 测试技术
JAX 中文文档(八)(5)
JAX 中文文档(八)
37 0
|
4月前
|
机器学习/深度学习 异构计算 Python
JAX 中文文档(四)(3)
JAX 中文文档(四)
29 0
|
4月前
|
存储 并行计算 数据可视化
JAX 中文文档(六)(3)
JAX 中文文档(六)
30 0
|
4月前
|
机器学习/深度学习 程序员 编译器
JAX 中文文档(三)(1)
JAX 中文文档(三)
37 0
|
4月前
|
并行计算 测试技术 异构计算
JAX 中文文档(一)(5)
JAX 中文文档(一)
76 0
|
4月前
|
机器学习/深度学习 测试技术 索引
JAX 中文文档(二)(4)
JAX 中文文档(二)
49 0
|
4月前
|
API Python
JAX 中文文档(八)(3)
JAX 中文文档(八)
33 0
|
4月前
|
测试技术 API Python
JAX 中文文档(八)(4)
JAX 中文文档(八)
31 0