JAX 中文文档(三)(4)https://developer.aliyun.com/article/1559705
JAX 调试标志
JAX 提供了标志和上下文管理器,可更轻松地捕获错误。
jax_debug_nans
配置选项和上下文管理器
简而言之 启用 jax_debug_nans
标志可自动检测在 jax.jit
编译的代码中产生 NaN(但不适用于 jax.pmap
或 jax.pjit
编译的代码)。
jax_debug_nans
是一个 JAX 标志,当启用时,会在检测到 NaN 时自动引发错误。它对 JIT 编译有特殊处理——如果从 JIT 编译函数检测到 NaN 输出,函数会急切地重新运行(即不经过编译),并在产生 NaN 的具体原始基元处引发错误。
用法
如果您想追踪函数或梯度中出现 NaN 的位置,可以通过以下方式打开 NaN 检查器:
- 设置
JAX_DEBUG_NANS=True
环境变量; - 在主文件顶部附近添加
jax.config.update("jax_debug_nans", True)
; - 在主文件添加
jax.config.parse_flags_with_absl()
,然后像--jax_debug_nans=True
这样使用命令行标志设置选项;
示例
import jax jax.config.update("jax_debug_nans", True) def f(x, y): return x / y jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception!
jax_debug_nans
的优势和限制
优势
- 易于应用
- 精确检测产生 NaN 的位置
- 抛出标准的 Python 异常,与 PDB 事后调试兼容
限制
- 与
jax.pmap
或jax.pjit
不兼容 - 急切重新运行函数可能会很慢
- 误报阳性(例如故意创建 NaN)
jax_disable_jit
配置选项和上下文管理器
简而言之 启用 jax_disable_jit
标志可禁用 JIT 编译,从而启用传统的 Python 调试工具如 print
和 pdb
。
jax_disable_jit
是一个 JAX 标志,当启用时,会在整个 JAX 中禁用 JIT 编译(包括在控制流函数如 jax.lax.cond
和 jax.lax.scan
中)。
用法
您可以通过以下方式禁用 JIT 编译:
- 设置
JAX_DISABLE_JIT=True
环境变量; - 在主文件顶部附近添加
jax.config.update("jax_disable_jit", True)
; - 在主文件添加
jax.config.parse_flags_with_absl()
,然后像--jax_disable_jit=True
这样使用命令行标志设置选项;
示例
import jax jax.config.update("jax_disable_jit", True) def f(x): y = jnp.log(x) if jnp.isnan(y): breakpoint() return y jax.jit(f)(-2.) # ==> Enters PDB breakpoint!
jax_disable_jit
的优势和限制
优势
- 易于应用
- 启用 Python 内置的
breakpoint
和print
- 抛出标准的 Python 异常,与 PDB 事后调试兼容
限制
- 与
jax.pmap
或jax.pjit
不兼容 - 在没有 JIT 编译的情况下运行函数可能会很慢
GPU 性能提示
本文档专注于神经网络工作负载的性能提示。
矩阵乘法精度
在像 Nvidia A100 一代或更高的最新 GPU 代中,将大多数计算以 bfloat16
精度执行可能是个好主意。例如,如果使用 Flax,可以使用 flax.linen.Dense(..., dtype=jax.numpy.bfloat16)
实例化 Dense
层。以下是一些代码示例:
- 在 Flax LM1B example 中,
Dense
模块也可以使用可配置的数据类型 进行实例化,其 默认值 为 bfloat16。 - 在 MaxText 中,
DenseGeneral
模块也可用可配置的数据类型 进行实例化,其 默认值为 bfloat16。
XLA 性能标志
注意
JAX-Toolbox 还有一个关于 NVIDIA XLA 性能 FLAGS 的页面。
XLA 标志的存在和确切行为可能取决于 jaxlib
版本。
截至 jaxlib==0.4.18
(发布于 2023 年 10 月 6 日),设置这些 XLA 标志可以提高性能。其中一些与多 GPU 之间的通信相关,因此仅在多设备运行计算时才相关,而其他一些与每个设备上的代码生成相关。
未来版本中可能会默认设置其中一些。
这些标志可以通过 XLA_FLAGS
shell 环境变量进行设置。例如,我们可以将其添加到 Python 文件的顶部:
import os os.environ['XLA_FLAGS'] = ( '--xla_gpu_enable_triton_softmax_fusion=true ' '--xla_gpu_triton_gemm_any=True ' '--xla_gpu_enable_async_collectives=true ' '--xla_gpu_enable_latency_hiding_scheduler=true ' '--xla_gpu_enable_highest_priority_async_stream=true ' )
更多示例,请参阅 XLA Flags recommended for Pax training on Nvidia GPUs。
代码生成标志
- –xla_gpu_enable_triton_softmax_fusion 此标志启用基于 Triton 代码生成支持的模式匹配自动 softmax 融合。默认值为 False。
- –xla_gpu_triton_gemm_any 使用基于 Triton 的 GEMM(矩阵乘法)发射器支持的任何 GEMM。默认值为 False。
通信标志
- –xla_gpu_enable_async_collectives 此标志启用诸如
AllReduce
、AllGather
、ReduceScatter
和CollectivePermute
等集体操作以异步方式进行。异步通信可以将跨核心通信与计算重叠。默认值为 False。 - –xla_gpu_enable_latency_hiding_scheduler 这个标志启用了延迟隐藏调度器,可以高效地将异步通信与计算重叠。默认值为 False。
- –xla_gpu_enable_pipelined_collectives 在使用管道并行时,此标志允许将(i+1)层权重的
AllGather
与第 i 层的计算重叠。它还允许将(i+1)层权重的Reduce
/ReduceScatter
与第 i 层的计算重叠。默认值为 False。在启用此标志时存在一些错误。 - –xla_gpu_collective_permute_decomposer_threshold 当执行GSPMD pipelining时,这个标志非常有用。设置一个非零的阈值会将
CollectivePermute
分解为CollectivePermuteReceiveDone
和CollectivePermuteSendDone
对,从而可以在每个对应的ReceiveDone
/SendDone
对之间执行计算,从而实现更多的重叠。默认阈值为 0,不进行分解。将其设置为大于 0 的阈值,例如--xla_gpu_collective_permute_decomposer_threshold=1024
,可以启用此功能。 - –xla_gpu_all_gather_combine_threshold_bytes –xla_gpu_reduce_scatter_combine_threshold_bytes –xla_gpu_all_reduce_combine_threshold_bytes 这些标志用于调整何时将多个小的
AllGather
/ReduceScatter
/AllReduce
组合成一个大的AllGather
/ReduceScatter
/AllReduce
,以减少跨设备通信所花费的时间。例如,在基于 Transformer 的工作负载上,可以考虑将AllGather
/ReduceScatter
阈值调高,以至少组合一个 Transformer 层的权重AllGather
/ReduceScatter
。默认情况下,combine_threshold_bytes
设置为 256。
NCCL 标志
这些 Nvidia NCCL 标志值可能对在 Nvidia GPU 上进行单主机多设备计算有用:
os.environ.update({ "NCCL_LL128_BUFFSIZE": "-2", "NCCL_LL_BUFFSIZE": "-2", "NCCL_PROTO": "SIMPLE,LL,LL128", })
这些 NCCL 标志可以提高单主机通信速度。然而,这些标志对多主机通信似乎不太有用。
多进程
我们建议每个 GPU 使用一个进程,而不是每个节点使用一个进程。在某些情况下,这可以加速 jitted 计算。当在 SLURM 下运行时,jax.distributed.initialize()
API 将自动理解此配置。然而,这只是一个经验法则,可能有必要在您的用例中测试每个 GPU 一个进程和每个节点一个进程的情况。
持久编译缓存
原文:
jax.readthedocs.io/en/latest/persistent_compilation_cache.html
JAX 具有可选的磁盘缓存用于编译程序。如果启用,JAX 将在磁盘上存储编译程序的副本,这在重复运行相同或类似任务时可以节省重新编译时间。
使用
当设置了cache-location时,编译缓存将启用。这应在第一次编译之前完成。设置位置如下:
import jax # Make sure this is called before jax runs any operations! jax.config.update("jax_compilation_cache_dir", "cache-location")
有关cache-location
的更多详细信息,请参见以下各节。
set_cache_dir()
是设置cache-location
的另一种方法。
本地文件系统
cache-location
可以是本地文件系统上的目录。例如:
import jax jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache")
注意:缓存没有实现驱逐机制。如果cache-location
是本地文件系统中的目录,则其大小将继续增长,除非手动删除文件。
Google Cloud
在 Google Cloud 上运行时,可以将编译缓存放置在 Google Cloud Storage(GCS)存储桶中。我们建议采用以下配置:
- 在与工作负载运行地区相同的地方创建存储桶。
- 在与工作负载的 VM 相同的项目中创建存储桶。确保设置了权限,使 VM 能够向存储桶写入。
- 对于较小的工作负载,不需要复制。较大的工作负载可能会受益于复制。
- 对于存储桶的默认存储类别,请使用“标准”。
- 将软删除策略设置为最短期限:7 天。
- 将对象生命周期设置为预期的工作负载运行时间。例如,如果工作负载预计运行 10 天,则将对象生命周期设置为 10 天。这应该涵盖整个运行期间发生的重启。使用
age
作为生命周期条件,使用Delete
作为操作。详情请参见对象生命周期管理。如果未设置对象生命周期,则缓存将继续增长,因为没有实现驱逐机制。 - 所有加密策略都受支持。
假设gs://jax-cache
是 GCS 存储桶,请设置如下cache-location
:
import jax jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")
tter。默认情况下,
combine_threshold_bytes`设置为 256。
NCCL 标志
这些 Nvidia NCCL 标志值可能对在 Nvidia GPU 上进行单主机多设备计算有用:
os.environ.update({ "NCCL_LL128_BUFFSIZE": "-2", "NCCL_LL_BUFFSIZE": "-2", "NCCL_PROTO": "SIMPLE,LL,LL128", })
这些 NCCL 标志可以提高单主机通信速度。然而,这些标志对多主机通信似乎不太有用。
多进程
我们建议每个 GPU 使用一个进程,而不是每个节点使用一个进程。在某些情况下,这可以加速 jitted 计算。当在 SLURM 下运行时,jax.distributed.initialize()
API 将自动理解此配置。然而,这只是一个经验法则,可能有必要在您的用例中测试每个 GPU 一个进程和每个节点一个进程的情况。
持久编译缓存
原文:
jax.readthedocs.io/en/latest/persistent_compilation_cache.html
JAX 具有可选的磁盘缓存用于编译程序。如果启用,JAX 将在磁盘上存储编译程序的副本,这在重复运行相同或类似任务时可以节省重新编译时间。
使用
当设置了cache-location时,编译缓存将启用。这应在第一次编译之前完成。设置位置如下:
import jax # Make sure this is called before jax runs any operations! jax.config.update("jax_compilation_cache_dir", "cache-location")
有关cache-location
的更多详细信息,请参见以下各节。
set_cache_dir()
是设置cache-location
的另一种方法。
本地文件系统
cache-location
可以是本地文件系统上的目录。例如:
import jax jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache")
注意:缓存没有实现驱逐机制。如果cache-location
是本地文件系统中的目录,则其大小将继续增长,除非手动删除文件。
Google Cloud
在 Google Cloud 上运行时,可以将编译缓存放置在 Google Cloud Storage(GCS)存储桶中。我们建议采用以下配置:
- 在与工作负载运行地区相同的地方创建存储桶。
- 在与工作负载的 VM 相同的项目中创建存储桶。确保设置了权限,使 VM 能够向存储桶写入。
- 对于较小的工作负载,不需要复制。较大的工作负载可能会受益于复制。
- 对于存储桶的默认存储类别,请使用“标准”。
- 将软删除策略设置为最短期限:7 天。
- 将对象生命周期设置为预期的工作负载运行时间。例如,如果工作负载预计运行 10 天,则将对象生命周期设置为 10 天。这应该涵盖整个运行期间发生的重启。使用
age
作为生命周期条件,使用Delete
作为操作。详情请参见对象生命周期管理。如果未设置对象生命周期,则缓存将继续增长,因为没有实现驱逐机制。 - 所有加密策略都受支持。
假设gs://jax-cache
是 GCS 存储桶,请设置如下cache-location
:
import jax jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")