JAX 中文文档(三)(5)

简介: JAX 中文文档(三)

JAX 中文文档(三)(4)https://developer.aliyun.com/article/1559705


JAX 调试标志

原文:jax.readthedocs.io/en/latest/debugging/flags.html

JAX 提供了标志和上下文管理器,可更轻松地捕获错误。

jax_debug_nans 配置选项和上下文管理器

简而言之 启用 jax_debug_nans 标志可自动检测在 jax.jit 编译的代码中产生 NaN(但不适用于 jax.pmapjax.pjit 编译的代码)。

jax_debug_nans 是一个 JAX 标志,当启用时,会在检测到 NaN 时自动引发错误。它对 JIT 编译有特殊处理——如果从 JIT 编译函数检测到 NaN 输出,函数会急切地重新运行(即不经过编译),并在产生 NaN 的具体原始基元处引发错误。

用法

如果您想追踪函数或梯度中出现 NaN 的位置,可以通过以下方式打开 NaN 检查器:

  • 设置 JAX_DEBUG_NANS=True 环境变量;
  • 在主文件顶部附近添加 jax.config.update("jax_debug_nans", True)
  • 在主文件添加 jax.config.parse_flags_with_absl(),然后像 --jax_debug_nans=True 这样使用命令行标志设置选项;

示例

import jax
jax.config.update("jax_debug_nans", True)
def f(x, y):
  return x / y
jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception! 
jax_debug_nans 的优势和限制
优势
  • 易于应用
  • 精确检测产生 NaN 的位置
  • 抛出标准的 Python 异常,与 PDB 事后调试兼容
限制
  • jax.pmapjax.pjit 不兼容
  • 急切重新运行函数可能会很慢
  • 误报阳性(例如故意创建 NaN)

jax_disable_jit 配置选项和上下文管理器

简而言之 启用 jax_disable_jit 标志可禁用 JIT 编译,从而启用传统的 Python 调试工具如 printpdb

jax_disable_jit 是一个 JAX 标志,当启用时,会在整个 JAX 中禁用 JIT 编译(包括在控制流函数如 jax.lax.condjax.lax.scan 中)。

用法

您可以通过以下方式禁用 JIT 编译:

  • 设置 JAX_DISABLE_JIT=True 环境变量;
  • 在主文件顶部附近添加 jax.config.update("jax_disable_jit", True)
  • 在主文件添加 jax.config.parse_flags_with_absl(),然后像 --jax_disable_jit=True 这样使用命令行标志设置选项;

示例

import jax
jax.config.update("jax_disable_jit", True)
def f(x):
  y = jnp.log(x)
  if jnp.isnan(y):
    breakpoint()
  return y
jax.jit(f)(-2.)  # ==> Enters PDB breakpoint! 
jax_disable_jit 的优势和限制
优势
  • 易于应用
  • 启用 Python 内置的 breakpointprint
  • 抛出标准的 Python 异常,与 PDB 事后调试兼容
限制
  • jax.pmapjax.pjit 不兼容
  • 在没有 JIT 编译的情况下运行函数可能会很慢

GPU 性能提示

原文:jax.readthedocs.io/en/latest/gpu_performance_tips.html

本文档专注于神经网络工作负载的性能提示。

矩阵乘法精度

在像 Nvidia A100 一代或更高的最新 GPU 代中,将大多数计算以 bfloat16 精度执行可能是个好主意。例如,如果使用 Flax,可以使用 flax.linen.Dense(..., dtype=jax.numpy.bfloat16) 实例化 Dense 层。以下是一些代码示例:

XLA 性能标志

注意

JAX-Toolbox 还有一个关于 NVIDIA XLA 性能 FLAGS 的页面。

XLA 标志的存在和确切行为可能取决于 jaxlib 版本。

截至 jaxlib==0.4.18(发布于 2023 年 10 月 6 日),设置这些 XLA 标志可以提高性能。其中一些与多 GPU 之间的通信相关,因此仅在多设备运行计算时才相关,而其他一些与每个设备上的代码生成相关。

未来版本中可能会默认设置其中一些。

这些标志可以通过 XLA_FLAGS shell 环境变量进行设置。例如,我们可以将其添加到 Python 文件的顶部:

import os
os.environ['XLA_FLAGS'] = (
    '--xla_gpu_enable_triton_softmax_fusion=true '
    '--xla_gpu_triton_gemm_any=True '
    '--xla_gpu_enable_async_collectives=true '
    '--xla_gpu_enable_latency_hiding_scheduler=true '
    '--xla_gpu_enable_highest_priority_async_stream=true '
) 

更多示例,请参阅 XLA Flags recommended for Pax training on Nvidia GPUs

代码生成标志

  • –xla_gpu_enable_triton_softmax_fusion 此标志启用基于 Triton 代码生成支持的模式匹配自动 softmax 融合。默认值为 False。
  • –xla_gpu_triton_gemm_any 使用基于 Triton 的 GEMM(矩阵乘法)发射器支持的任何 GEMM。默认值为 False。

通信标志

  • –xla_gpu_enable_async_collectives 此标志启用诸如AllReduceAllGatherReduceScatterCollectivePermute等集体操作以异步方式进行。异步通信可以将跨核心通信与计算重叠。默认值为 False。
  • –xla_gpu_enable_latency_hiding_scheduler 这个标志启用了延迟隐藏调度器,可以高效地将异步通信与计算重叠。默认值为 False。
  • –xla_gpu_enable_pipelined_collectives 在使用管道并行时,此标志允许将(i+1)层权重的AllGather与第 i 层的计算重叠。它还允许将(i+1)层权重的Reduce/ReduceScatter与第 i 层的计算重叠。默认值为 False。在启用此标志时存在一些错误。
  • –xla_gpu_collective_permute_decomposer_threshold 当执行GSPMD pipelining时,这个标志非常有用。设置一个非零的阈值会将CollectivePermute分解为CollectivePermuteReceiveDoneCollectivePermuteSendDone对,从而可以在每个对应的ReceiveDone/SendDone对之间执行计算,从而实现更多的重叠。默认阈值为 0,不进行分解。将其设置为大于 0 的阈值,例如--xla_gpu_collective_permute_decomposer_threshold=1024,可以启用此功能。
  • –xla_gpu_all_gather_combine_threshold_bytes –xla_gpu_reduce_scatter_combine_threshold_bytes –xla_gpu_all_reduce_combine_threshold_bytes 这些标志用于调整何时将多个小的AllGather/ReduceScatter/AllReduce组合成一个大的AllGather/ReduceScatter/AllReduce,以减少跨设备通信所花费的时间。例如,在基于 Transformer 的工作负载上,可以考虑将AllGather/ReduceScatter阈值调高,以至少组合一个 Transformer 层的权重AllGather/ReduceScatter。默认情况下,combine_threshold_bytes设置为 256。

NCCL 标志

这些 Nvidia NCCL 标志值可能对在 Nvidia GPU 上进行单主机多设备计算有用:

os.environ.update({
  "NCCL_LL128_BUFFSIZE": "-2",
  "NCCL_LL_BUFFSIZE": "-2",
   "NCCL_PROTO": "SIMPLE,LL,LL128",
 }) 

这些 NCCL 标志可以提高单主机通信速度。然而,这些标志对多主机通信似乎不太有用。

多进程

我们建议每个 GPU 使用一个进程,而不是每个节点使用一个进程。在某些情况下,这可以加速 jitted 计算。当在 SLURM 下运行时,jax.distributed.initialize() API 将自动理解此配置。然而,这只是一个经验法则,可能有必要在您的用例中测试每个 GPU 一个进程和每个节点一个进程的情况。

持久编译缓存

原文:jax.readthedocs.io/en/latest/persistent_compilation_cache.html

JAX 具有可选的磁盘缓存用于编译程序。如果启用,JAX 将在磁盘上存储编译程序的副本,这在重复运行相同或类似任务时可以节省重新编译时间。

使用

当设置了cache-location时,编译缓存将启用。这应在第一次编译之前完成。设置位置如下:

import jax
# Make sure this is called before jax runs any operations!
jax.config.update("jax_compilation_cache_dir", "cache-location") 

有关cache-location的更多详细信息,请参见以下各节。

set_cache_dir()是设置cache-location的另一种方法。

本地文件系统

cache-location可以是本地文件系统上的目录。例如:

import jax
jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache") 

注意:缓存没有实现驱逐机制。如果cache-location是本地文件系统中的目录,则其大小将继续增长,除非手动删除文件。

Google Cloud

在 Google Cloud 上运行时,可以将编译缓存放置在 Google Cloud Storage(GCS)存储桶中。我们建议采用以下配置:

  • 在与工作负载运行地区相同的地方创建存储桶。
  • 在与工作负载的 VM 相同的项目中创建存储桶。确保设置了权限,使 VM 能够向存储桶写入。
  • 对于较小的工作负载,不需要复制。较大的工作负载可能会受益于复制。
  • 对于存储桶的默认存储类别,请使用“标准”。
  • 将软删除策略设置为最短期限:7 天。
  • 将对象生命周期设置为预期的工作负载运行时间。例如,如果工作负载预计运行 10 天,则将对象生命周期设置为 10 天。这应该涵盖整个运行期间发生的重启。使用age作为生命周期条件,使用Delete作为操作。详情请参见对象生命周期管理。如果未设置对象生命周期,则缓存将继续增长,因为没有实现驱逐机制。
  • 所有加密策略都受支持。

假设gs://jax-cache是 GCS 存储桶,请设置如下cache-location

import jax
jax.config.update("jax_compilation_cache_dir", "gs://jax-cache") 

tter。默认情况下,combine_threshold_bytes`设置为 256。

NCCL 标志

这些 Nvidia NCCL 标志值可能对在 Nvidia GPU 上进行单主机多设备计算有用:

os.environ.update({
  "NCCL_LL128_BUFFSIZE": "-2",
  "NCCL_LL_BUFFSIZE": "-2",
   "NCCL_PROTO": "SIMPLE,LL,LL128",
 }) 

这些 NCCL 标志可以提高单主机通信速度。然而,这些标志对多主机通信似乎不太有用。

多进程

我们建议每个 GPU 使用一个进程,而不是每个节点使用一个进程。在某些情况下,这可以加速 jitted 计算。当在 SLURM 下运行时,jax.distributed.initialize() API 将自动理解此配置。然而,这只是一个经验法则,可能有必要在您的用例中测试每个 GPU 一个进程和每个节点一个进程的情况。

持久编译缓存

原文:jax.readthedocs.io/en/latest/persistent_compilation_cache.html

JAX 具有可选的磁盘缓存用于编译程序。如果启用,JAX 将在磁盘上存储编译程序的副本,这在重复运行相同或类似任务时可以节省重新编译时间。

使用

当设置了cache-location时,编译缓存将启用。这应在第一次编译之前完成。设置位置如下:

import jax
# Make sure this is called before jax runs any operations!
jax.config.update("jax_compilation_cache_dir", "cache-location") 

有关cache-location的更多详细信息,请参见以下各节。

set_cache_dir()是设置cache-location的另一种方法。

本地文件系统

cache-location可以是本地文件系统上的目录。例如:

import jax
jax.config.update("jax_compilation_cache_dir", "/tmp/jax-cache") 

注意:缓存没有实现驱逐机制。如果cache-location是本地文件系统中的目录,则其大小将继续增长,除非手动删除文件。

Google Cloud

在 Google Cloud 上运行时,可以将编译缓存放置在 Google Cloud Storage(GCS)存储桶中。我们建议采用以下配置:

  • 在与工作负载运行地区相同的地方创建存储桶。
  • 在与工作负载的 VM 相同的项目中创建存储桶。确保设置了权限,使 VM 能够向存储桶写入。
  • 对于较小的工作负载,不需要复制。较大的工作负载可能会受益于复制。
  • 对于存储桶的默认存储类别,请使用“标准”。
  • 将软删除策略设置为最短期限:7 天。
  • 将对象生命周期设置为预期的工作负载运行时间。例如,如果工作负载预计运行 10 天,则将对象生命周期设置为 10 天。这应该涵盖整个运行期间发生的重启。使用age作为生命周期条件,使用Delete作为操作。详情请参见对象生命周期管理。如果未设置对象生命周期,则缓存将继续增长,因为没有实现驱逐机制。
  • 所有加密策略都受支持。

假设gs://jax-cache是 GCS 存储桶,请设置如下cache-location

import jax
jax.config.update("jax_compilation_cache_dir", "gs://jax-cache") 
相关实践学习
在云上部署ChatGLM2-6B大模型(GPU版)
ChatGLM2-6B是由智谱AI及清华KEG实验室于2023年6月发布的中英双语对话开源大模型。通过本实验,可以学习如何配置AIGC开发环境,如何部署ChatGLM2-6B大模型。
相关文章
|
运维 资源调度 Kubernetes
Kubernetes Scheduler Framework 扩展: 1. Coscheduling
# 前言 ## 为什么Kubernetes需要Coscheduling功能? Kubernetes目前已经广泛的应用于在线服务编排,为了提升集群的的利用率和运行效率,我们希望将Kubernetes作为一个统一的管理平台来管理在线服务和离线作业。但是默认的调度器是以Pod为调度单元进行依次调度,不会考虑Pod之间的相互关系。但是很多数据计算类的作业具有All-or-Nothing特点,要求所有的
3630 0
|
Linux Shell 网络安全
【Shell 命令集合 网络通讯 】Linux 与SMB服务器进行交互 smbclient命令 使用指南
【Shell 命令集合 网络通讯 】Linux 与SMB服务器进行交互 smbclient命令 使用指南
1073 1
|
SQL 数据可视化 数据处理
使用SQL和Python处理Excel文件数据
使用SQL和Python处理Excel文件数据
980 0
|
机器学习/深度学习 PyTorch 编译器
深入解析torch.compile:提升PyTorch模型性能、高效解决常见问题
PyTorch 2.0推出的`torch.compile`功能为深度学习模型带来了显著的性能优化能力。本文从实用角度出发,详细介绍了`torch.compile`的核心技巧与应用场景,涵盖模型复杂度评估、可编译组件分析、系统化调试策略及性能优化高级技巧等内容。通过解决图断裂、重编译频繁等问题,并结合分布式训练和NCCL通信优化,开发者可以有效提升日常开发效率与模型性能。文章为PyTorch用户提供了全面的指导,助力充分挖掘`torch.compile`的潜力。
1396 17
|
安全 Go 调度
Go同步原语与数据竞争:原子操作(atomic)
本文介绍了Go语言中`sync/atomic`包的使用,帮助避免多goroutine并发操作时的数据竞争问题。原子操作是一种不可中断的操作,确保变量读写的安全性。文章详细说明了常用函数如`Load`、`Store`、`Add`和`CompareAndSwap`的功能与应用场景,并通过并发计数器示例展示了其实现方式。此外,对比了原子操作与锁的优缺点,强调原子操作适用于简单变量的高效同步,而不适合复杂数据结构。最后提醒开发者注意使用场景限制,合理选择同步工具以优化性能。
|
缓存 安全 网络协议
|
机器学习/深度学习 人工智能 PyTorch
【Hello AI】神龙AI加速引擎AIACC-加速深度学习应用
神龙AI加速引擎AIACC是基于阿里云IaaS资源推出的AI加速引擎,用于优化基于AI主流计算框架搭建的模型,使用AIACC可加速深度学习应用,能显著提升模型的训练和推理性能。
|
弹性计算
阿里云服务器公网带宽最高只能是100Mbps吗?
阿里云服务器公网带宽最高只能是100Mbps吗?阿里云服务器公网带宽最高只有100Mbps,不够用怎么办?阿里云的带宽居然只有100M,对于需要大带宽云服务器的用户怎么办?难道只能租物理机吗?并不是,可以将云服务器固定公网IP转成弹性公网EIP,弹性公网IP带宽最高可选1000Mbps,完全够用
阿里云服务器公网带宽最高只能是100Mbps吗?
|
分布式计算 监控 Oracle
Spark Standalone环境搭建及测试
Spark Standalone环境搭建及测试
444 0
|
机器学习/深度学习 算法 C++
动手强化学习(九):策略梯度算法
 首先定义策略网络PolicyNet,其输入是某个状态,输出则是该状态下的动作概率分布,这里采用在离散动作空间上的softmax()函数来实现一个可学习的多项分布(multinomial distribution)。
1094 0