JAX 中文文档(十三)(3)

简介: JAX 中文文档(十三)

JAX 中文文档(十三)(2)https://developer.aliyun.com/article/1559742


GPU 内存分配

原文:jax.readthedocs.io/en/latest/gpu_memory_allocation.html

当第一个 JAX 操作运行时,JAX 将预先分配总 GPU 内存的 75%。 预先分配可以最小化分配开销和内存碎片化,但有时会导致内存不足(OOM)错误。如果您的 JAX 进程因内存不足而失败,可以使用以下环境变量来覆盖默认行为:

XLA_PYTHON_CLIENT_PREALLOCATE=false

这将禁用预分配行为。JAX 将根据需要分配 GPU 内存,可能会减少总体内存使用。但是,这种行为更容易导致 GPU 内存碎片化,这意味着使用大部分可用 GPU 内存的 JAX 程序可能会在禁用预分配时发生 OOM。

XLA_PYTHON_CLIENT_MEM_FRACTION=.XX

如果启用了预分配,这将使 JAX 预分配总 GPU 内存的 XX% ,而不是默认的 75%。减少预分配量可以修复 JAX 程序启动时的内存不足问题。

XLA_PYTHON_CLIENT_ALLOCATOR=platform

这使得 JAX 根据需求精确分配内存,并释放不再需要的内存(请注意,这是唯一会释放 GPU 内存而不是重用它的配置)。这样做非常慢,因此不建议用于一般用途,但可能对于以最小可能的 GPU 内存占用运行或调试 OOM 失败非常有用。

OOM 失败的常见原因

同时运行多个 JAX 进程。

要么使用 XLA_PYTHON_CLIENT_MEM_FRACTION 为每个进程分配适当的内存量,要么设置 XLA_PYTHON_CLIENT_PREALLOCATE=false

同时运行 JAX 和 GPU TensorFlow。

TensorFlow 默认也会预分配,因此这与同时运行多个 JAX 进程类似。

一个解决方案是仅使用 CPU TensorFlow(例如,如果您仅使用 TF 进行数据加载)。您可以使用命令 tf.config.experimental.set_visible_devices([], "GPU") 阻止 TensorFlow 使用 GPU。

或者,使用 XLA_PYTHON_CLIENT_MEM_FRACTIONXLA_PYTHON_CLIENT_PREALLOCATE。还有类似的选项可以配置 TensorFlow 的 GPU 内存分配(gpu_memory_fractionallow_growth 在 TF1 中应该设置在传递给 tf.Sessiontf.ConfigProto 中。参见 使用 GPU:限制 GPU 内存增长 用于 TF2)。

在显示 GPU 上运行 JAX。

使用 XLA_PYTHON_CLIENT_MEM_FRACTIONXLA_PYTHON_CLIENT_PREALLOCATE

提升秩警告

原文:jax.readthedocs.io/en/latest/rank_promotion_warning.html

NumPy 广播规则 允许自动将参数从一个秩(数组轴的数量)提升到另一个秩。当意图明确时,此行为很方便,但也可能导致意外的错误,其中静默的秩提升掩盖了潜在的形状错误。

下面是提升秩的示例:

>>> import numpy as np
>>> x = np.arange(12).reshape(4, 3)
>>> y = np.array([0, 1, 0])
>>> x + y
array([[ 0,  2,  2],
 [ 3,  5,  5],
 [ 6,  8,  8],
 [ 9, 11, 11]]) 

为了避免潜在的意外,jax.numpy 可配置,以便需要提升秩的表达式会导致警告、错误或像常规 NumPy 一样允许。配置选项名为 jax_numpy_rank_promotion,可以取字符串值 allowwarnraise。默认设置为 allow,允许提升秩而不警告或错误。设置为 raise 则在提升秩时引发错误,而 warn 在首次提升秩时引发警告。

可以使用 jax.numpy_rank_promotion() 上下文管理器在本地启用或禁用提升秩:

with jax.numpy_rank_promotion("warn"):
  z = x + y 

这个配置也可以在多种全局方式下设置。其中一种是在代码中使用 jax.config

import jax
jax.config.update("jax_numpy_rank_promotion", "warn") 

也可以使用环境变量 JAX_NUMPY_RANK_PROMOTION 来设置选项,例如 JAX_NUMPY_RANK_PROMOTION='warn'。最后,在使用 absl-py 时,可以使用命令行标志设置选项。

公共 API:jax 包

原文:jax.readthedocs.io/en/latest/jax.html

子包

  • jax.numpy 模块
  • jax.scipy 模块
  • jax.lax 模块
  • jax.random 模块
  • jax.sharding 模块
  • jax.debug 模块
  • jax.dlpack 模块
  • jax.distributed 模块
  • jax.dtypes 模块
  • jax.flatten_util 模块
  • jax.image 模块
  • jax.nn 模块
  • jax.ops 模块
  • jax.profiler 模块
  • jax.stages 模块
  • jax.tree 模块
  • jax.tree_util 模块
  • jax.typing 模块
  • jax.export 模块
  • jax.extend 模块
  • jax.example_libraries 模块
  • jax.experimental 模块

配置

config
check_tracer_leaks jax_check_tracer_leaks 配置选项的上下文管理器。
checking_leaks jax_check_tracer_leaks 配置选项的上下文管理器。
debug_nans jax_debug_nans 配置选项的上下文管理器。
debug_infs jax_debug_infs 配置选项的上下文管理器。
default_device jax_default_device 配置选项的上下文管理器。
default_matmul_precision jax_default_matmul_precision 配置选项的上下文管理器。
default_prng_impl jax_default_prng_impl 配置选项的上下文管理器。
enable_checks jax_enable_checks 配置选项的上下文管理器。
enable_custom_prng jax_enable_custom_prng 配置选项的上下文管理器(临时)。
enable_custom_vjp_by_custom_transpose jax_enable_custom_vjp_by_custom_transpose 配置选项的上下文管理器(临时)。
log_compiles jax_log_compiles 配置选项的上下文管理器。
numpy_rank_promotion jax_numpy_rank_promotion 配置选项的上下文管理器。
transfer_guard(new_val) 控制所有传输的传输保护级别的上下文管理器。

即时编译 (jit)

jit(fun[, in_shardings, out_shardings, …]) 使用 XLA 设置 fun 进行即时编译。
disable_jit([disable]) 禁用其动态上下文下 jit() 行为的上下文管理器。
ensure_compile_time_eval() 确保在追踪/编译时进行评估的上下文管理器(或错误)。
xla_computation(fun[, static_argnums, …]) 创建一个函数,给定示例参数,产生其 XLA 计算。
make_jaxpr([axis_env, return_shape, …]) 创建一个函数,给定示例参数,产生其 jaxpr。
eval_shape(fun, *args, **kwargs) 计算 fun 的形状/数据类型,不进行任何 FLOP 计算。
ShapeDtypeStruct(shape, dtype[, …]) 数组的形状、dtype 和其他静态属性的容器。
device_put(x[, device, src]) x 传输到 device
device_put_replicated(x, devices) 将数组传输到每个指定的设备并形成数组。
device_put_sharded(shards, devices) 将数组片段传输到指定设备并形成数组。
device_get(x) x 传输到主机。
default_backend() 返回默认 XLA 后端的平台名称。
named_call(fun, *[, name]) 在 JAX 计算中给函数添加用户指定的名称。
named_scope(name) 将用户指定的名称添加到 JAX 名称堆栈的上下文管理器。

| block_until_ready(x) | 尝试调用 pytree 叶子上的 block_until_ready 方法。 | ## 自动微分

grad(fun[, argnums, has_aux, holomorphic, …]) 创建一个评估 fun 梯度的函数。
value_and_grad(fun[, argnums, has_aux, …]) 创建一个同时评估 funfun 梯度的函数。
jacfwd(fun[, argnums, has_aux, holomorphic]) 使用正向模式自动微分逐列计算 fun 的雅可比矩阵。
jacrev(fun[, argnums, has_aux, holomorphic, …]) 使用反向模式自动微分逐行计算 fun 的雅可比矩阵。
hessian(fun[, argnums, has_aux, holomorphic]) fun 的 Hessian 矩阵作为稠密数组。
jvp(fun, primals, tangents[, has_aux]) 计算 fun 的(正向模式)雅可比向量乘积。
linearize() 使用 jvp() 和部分求值生成对 fun 的线性近似。
linear_transpose(fun, *primals[, reduce_axes]) 转置一个承诺为线性的函数。
vjp() )) 计算 fun 的(反向模式)向量-Jacobian 乘积。
custom_jvp(fun[, nondiff_argnums]) 为自定义 JVP 规则定义一个可 JAX 化的函数。
custom_vjp(fun[, nondiff_argnums]) 为自定义 VJP 规则定义一个可 JAX 化的函数。
custom_gradient(fun) 方便地定义自定义的 VJP 规则(即自定义梯度)。
closure_convert(fun, *example_args) 闭包转换实用程序,用于与高阶自定义导数一起使用。
checkpoint(fun, *[, prevent_cse, policy, …]) 使 fun 在求导时重新计算内部线性化点。

jax.Array (jax.Array)

Array() JAX 的数组基类
make_array_from_callback(shape, sharding, …) 通过从 data_callback 获取的数据返回一个 jax.Array
make_array_from_single_device_arrays(shape, …) 从每个位于单个设备上的 jax.Array 序列返回一个 jax.Array
make_array_from_process_local_data(sharding, …) 使用进程中可用的数据创建分布式张量。

向量化 (vmap)

vmap(fun[, in_axes, out_axes, axis_name, …]) 向量化映射。
numpy.vectorize(pyfunc, *[, excluded, signature]) 定义一个支持广播的向量化函数。

并行化 (pmap)

pmap(fun[, axis_name, in_axes, out_axes, …]) 支持集体操作的并行映射。
devices([backend]) 返回给定后端的所有设备列表。
local_devices([process_index, backend, host_id]) 类似于 jax.devices(),但仅返回给定进程局部的设备。
process_index([backend]) 返回此进程的整数进程索引。
device_count([backend]) 返回设备的总数。
local_device_count([backend]) 返回此进程可寻址的设备数量。
process_count([backend]) 返回与后端关联的 JAX 进程数。

Callbacks

pure_callback(callback, result_shape_dtypes, …) 调用一个纯 Python 回调函数。
experimental.io_callback(callback, …[, …]) 调用一个非纯 Python 回调函数。
debug.callback(callback, *args[, ordered]) 调用一个可分期的 Python 回调函数。
debug.print(fmt, *args[, ordered]) 打印值,并在分期 JAX 函数中工作。

Miscellaneous

Device 可用设备的描述符。
print_environment_info([return_string]) 返回一个包含本地环境和 JAX 安装信息的字符串。
live_arrays([platform]) 返回后端平台上的所有活动数组。
clear_caches() 清除所有编译和分期缓存。


JAX 中文文档(十三)(4)https://developer.aliyun.com/article/1559745

相关实践学习
部署Stable Diffusion玩转AI绘画(GPU云服务器)
本实验通过在ECS上从零开始部署Stable Diffusion来进行AI绘画创作,开启AIGC盲盒。
相关文章
|
4月前
|
机器学习/深度学习 编译器 API
JAX 中文文档(十三)(4)
JAX 中文文档(十三)
56 2
|
4月前
|
算法 Serverless 索引
JAX 中文文档(十三)(5)
JAX 中文文档(十三)
29 1
|
4月前
|
存储 API 索引
JAX 中文文档(十五)(5)
JAX 中文文档(十五)
53 3
|
4月前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(4)
JAX 中文文档(十五)
40 3
|
4月前
|
API 异构计算 Python
JAX 中文文档(十一)(4)
JAX 中文文档(十一)
33 1
|
4月前
JAX 中文文档(十一)(5)
JAX 中文文档(十一)
17 1
|
4月前
|
算法 API 开发工具
JAX 中文文档(十二)(5)
JAX 中文文档(十二)
55 1
|
4月前
|
测试技术 API 调度
JAX 中文文档(十三)(2)
JAX 中文文档(十三)
29 0
|
4月前
|
机器学习/深度学习 Shell API
JAX 中文文档(十三)(1)
JAX 中文文档(十三)
41 0
|
4月前
|
安全 算法 API
JAX 中文文档(十一)(3)
JAX 中文文档(十一)
29 0