JAX 中文文档(十三)(2)https://developer.aliyun.com/article/1559742
GPU 内存分配
当第一个 JAX 操作运行时,JAX 将预先分配总 GPU 内存的 75%。 预先分配可以最小化分配开销和内存碎片化,但有时会导致内存不足(OOM)错误。如果您的 JAX 进程因内存不足而失败,可以使用以下环境变量来覆盖默认行为:
XLA_PYTHON_CLIENT_PREALLOCATE=false
这将禁用预分配行为。JAX 将根据需要分配 GPU 内存,可能会减少总体内存使用。但是,这种行为更容易导致 GPU 内存碎片化,这意味着使用大部分可用 GPU 内存的 JAX 程序可能会在禁用预分配时发生 OOM。
XLA_PYTHON_CLIENT_MEM_FRACTION=.XX
如果启用了预分配,这将使 JAX 预分配总 GPU 内存的 XX% ,而不是默认的 75%。减少预分配量可以修复 JAX 程序启动时的内存不足问题。
XLA_PYTHON_CLIENT_ALLOCATOR=platform
这使得 JAX 根据需求精确分配内存,并释放不再需要的内存(请注意,这是唯一会释放 GPU 内存而不是重用它的配置)。这样做非常慢,因此不建议用于一般用途,但可能对于以最小可能的 GPU 内存占用运行或调试 OOM 失败非常有用。
OOM 失败的常见原因
同时运行多个 JAX 进程。
要么使用 XLA_PYTHON_CLIENT_MEM_FRACTION
为每个进程分配适当的内存量,要么设置 XLA_PYTHON_CLIENT_PREALLOCATE=false
。
同时运行 JAX 和 GPU TensorFlow。
TensorFlow 默认也会预分配,因此这与同时运行多个 JAX 进程类似。
一个解决方案是仅使用 CPU TensorFlow(例如,如果您仅使用 TF 进行数据加载)。您可以使用命令 tf.config.experimental.set_visible_devices([], "GPU")
阻止 TensorFlow 使用 GPU。
或者,使用 XLA_PYTHON_CLIENT_MEM_FRACTION
或 XLA_PYTHON_CLIENT_PREALLOCATE
。还有类似的选项可以配置 TensorFlow 的 GPU 内存分配(gpu_memory_fraction 和 allow_growth 在 TF1 中应该设置在传递给 tf.Session
的 tf.ConfigProto
中。参见 使用 GPU:限制 GPU 内存增长 用于 TF2)。
在显示 GPU 上运行 JAX。
使用 XLA_PYTHON_CLIENT_MEM_FRACTION
或 XLA_PYTHON_CLIENT_PREALLOCATE
。
提升秩警告
NumPy 广播规则 允许自动将参数从一个秩(数组轴的数量)提升到另一个秩。当意图明确时,此行为很方便,但也可能导致意外的错误,其中静默的秩提升掩盖了潜在的形状错误。
下面是提升秩的示例:
>>> import numpy as np >>> x = np.arange(12).reshape(4, 3) >>> y = np.array([0, 1, 0]) >>> x + y array([[ 0, 2, 2], [ 3, 5, 5], [ 6, 8, 8], [ 9, 11, 11]])
为了避免潜在的意外,jax.numpy
可配置,以便需要提升秩的表达式会导致警告、错误或像常规 NumPy 一样允许。配置选项名为 jax_numpy_rank_promotion
,可以取字符串值 allow
、warn
和 raise
。默认设置为 allow
,允许提升秩而不警告或错误。设置为 raise
则在提升秩时引发错误,而 warn
在首次提升秩时引发警告。
可以使用 jax.numpy_rank_promotion()
上下文管理器在本地启用或禁用提升秩:
with jax.numpy_rank_promotion("warn"): z = x + y
这个配置也可以在多种全局方式下设置。其中一种是在代码中使用 jax.config
:
import jax jax.config.update("jax_numpy_rank_promotion", "warn")
也可以使用环境变量 JAX_NUMPY_RANK_PROMOTION
来设置选项,例如 JAX_NUMPY_RANK_PROMOTION='warn'
。最后,在使用 absl-py
时,可以使用命令行标志设置选项。
公共 API:jax 包
子包
jax.numpy
模块jax.scipy
模块jax.lax
模块jax.random
模块jax.sharding
模块jax.debug
模块jax.dlpack
模块jax.distributed
模块jax.dtypes
模块jax.flatten_util
模块jax.image
模块jax.nn
模块jax.ops
模块jax.profiler
模块jax.stages
模块jax.tree
模块jax.tree_util
模块jax.typing
模块jax.export
模块jax.extend
模块jax.example_libraries
模块jax.experimental
模块
配置
config |
|
check_tracer_leaks |
jax_check_tracer_leaks 配置选项的上下文管理器。 |
checking_leaks |
jax_check_tracer_leaks 配置选项的上下文管理器。 |
debug_nans |
jax_debug_nans 配置选项的上下文管理器。 |
debug_infs |
jax_debug_infs 配置选项的上下文管理器。 |
default_device |
jax_default_device 配置选项的上下文管理器。 |
default_matmul_precision |
jax_default_matmul_precision 配置选项的上下文管理器。 |
default_prng_impl |
jax_default_prng_impl 配置选项的上下文管理器。 |
enable_checks |
jax_enable_checks 配置选项的上下文管理器。 |
enable_custom_prng |
jax_enable_custom_prng 配置选项的上下文管理器(临时)。 |
enable_custom_vjp_by_custom_transpose |
jax_enable_custom_vjp_by_custom_transpose 配置选项的上下文管理器(临时)。 |
log_compiles |
jax_log_compiles 配置选项的上下文管理器。 |
numpy_rank_promotion |
jax_numpy_rank_promotion 配置选项的上下文管理器。 |
transfer_guard (new_val) |
控制所有传输的传输保护级别的上下文管理器。 |
即时编译 (jit
)
jit (fun[, in_shardings, out_shardings, …]) |
使用 XLA 设置 fun 进行即时编译。 |
disable_jit ([disable]) |
禁用其动态上下文下 jit() 行为的上下文管理器。 |
ensure_compile_time_eval () |
确保在追踪/编译时进行评估的上下文管理器(或错误)。 |
xla_computation (fun[, static_argnums, …]) |
创建一个函数,给定示例参数,产生其 XLA 计算。 |
make_jaxpr ([axis_env, return_shape, …]) |
创建一个函数,给定示例参数,产生其 jaxpr。 |
eval_shape (fun, *args, **kwargs) |
计算 fun 的形状/数据类型,不进行任何 FLOP 计算。 |
ShapeDtypeStruct (shape, dtype[, …]) |
数组的形状、dtype 和其他静态属性的容器。 |
device_put (x[, device, src]) |
将 x 传输到 device 。 |
device_put_replicated (x, devices) |
将数组传输到每个指定的设备并形成数组。 |
device_put_sharded (shards, devices) |
将数组片段传输到指定设备并形成数组。 |
device_get (x) |
将 x 传输到主机。 |
default_backend () |
返回默认 XLA 后端的平台名称。 |
named_call (fun, *[, name]) |
在 JAX 计算中给函数添加用户指定的名称。 |
named_scope (name) |
将用户指定的名称添加到 JAX 名称堆栈的上下文管理器。 |
| block_until_ready
(x) | 尝试调用 pytree 叶子上的 block_until_ready
方法。 | ## 自动微分
grad (fun[, argnums, has_aux, holomorphic, …]) |
创建一个评估 fun 梯度的函数。 |
value_and_grad (fun[, argnums, has_aux, …]) |
创建一个同时评估 fun 和 fun 梯度的函数。 |
jacfwd (fun[, argnums, has_aux, holomorphic]) |
使用正向模式自动微分逐列计算 fun 的雅可比矩阵。 |
jacrev (fun[, argnums, has_aux, holomorphic, …]) |
使用反向模式自动微分逐行计算 fun 的雅可比矩阵。 |
hessian (fun[, argnums, has_aux, holomorphic]) |
fun 的 Hessian 矩阵作为稠密数组。 |
jvp (fun, primals, tangents[, has_aux]) |
计算 fun 的(正向模式)雅可比向量乘积。 |
linearize () |
使用 jvp() 和部分求值生成对 fun 的线性近似。 |
linear_transpose (fun, *primals[, reduce_axes]) |
转置一个承诺为线性的函数。 |
vjp () )) |
计算 fun 的(反向模式)向量-Jacobian 乘积。 |
custom_jvp (fun[, nondiff_argnums]) |
为自定义 JVP 规则定义一个可 JAX 化的函数。 |
custom_vjp (fun[, nondiff_argnums]) |
为自定义 VJP 规则定义一个可 JAX 化的函数。 |
custom_gradient (fun) |
方便地定义自定义的 VJP 规则(即自定义梯度)。 |
closure_convert (fun, *example_args) |
闭包转换实用程序,用于与高阶自定义导数一起使用。 |
checkpoint (fun, *[, prevent_cse, policy, …]) |
使 fun 在求导时重新计算内部线性化点。 |
jax.Array (jax.Array
)
Array () |
JAX 的数组基类 |
make_array_from_callback (shape, sharding, …) |
通过从 data_callback 获取的数据返回一个 jax.Array 。 |
make_array_from_single_device_arrays (shape, …) |
从每个位于单个设备上的 jax.Array 序列返回一个 jax.Array 。 |
make_array_from_process_local_data (sharding, …) |
使用进程中可用的数据创建分布式张量。 |
向量化 (vmap
)
vmap (fun[, in_axes, out_axes, axis_name, …]) |
向量化映射。 |
numpy.vectorize (pyfunc, *[, excluded, signature]) |
定义一个支持广播的向量化函数。 |
并行化 (pmap
)
pmap (fun[, axis_name, in_axes, out_axes, …]) |
支持集体操作的并行映射。 |
devices ([backend]) |
返回给定后端的所有设备列表。 |
local_devices ([process_index, backend, host_id]) |
类似于 jax.devices() ,但仅返回给定进程局部的设备。 |
process_index ([backend]) |
返回此进程的整数进程索引。 |
device_count ([backend]) |
返回设备的总数。 |
local_device_count ([backend]) |
返回此进程可寻址的设备数量。 |
process_count ([backend]) |
返回与后端关联的 JAX 进程数。 |
Callbacks
pure_callback (callback, result_shape_dtypes, …) |
调用一个纯 Python 回调函数。 |
experimental.io_callback (callback, …[, …]) |
调用一个非纯 Python 回调函数。 |
debug.callback (callback, *args[, ordered]) |
调用一个可分期的 Python 回调函数。 |
debug.print (fmt, *args[, ordered]) |
打印值,并在分期 JAX 函数中工作。 |
Miscellaneous
Device |
可用设备的描述符。 |
print_environment_info ([return_string]) |
返回一个包含本地环境和 JAX 安装信息的字符串。 |
live_arrays ([platform]) |
返回后端平台上的所有活动数组。 |
clear_caches () |
清除所有编译和分期缓存。 |
JAX 中文文档(十三)(4)https://developer.aliyun.com/article/1559745