jax.experimental.sparse.bcoo_multiply_dense
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_multiply_dense.html
jax.experimental.sparse.bcoo_multiply_dense(sp_mat, v)
稀疏数组和稠密数组之间的逐元素乘法。
参数:
- lhs – 一个 BCOO 格式的数组。
- rhs – 一个 ndarray。
- sp_mat(BCOO)
- v(Array)
返回:
包含结果的 ndarray。
返回类型:
Array
jax.experimental.sparse.bcoo_multiply_sparse
jax.experimental.sparse.bcoo_multiply_sparse(lhs, rhs)
两个稀疏数组的逐元素乘积。
参数:
- lhs (BCOO) – 一个 BCOO 格式的数组。
- rhs (BCOO) – 一个 BCOO 格式的数组。
返回值:
包含结果的 BCOO 格式数组。
返回类型:
BCOO
jax.experimental.sparse.bcoo_update_layout
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_update_layout.html
jax.experimental.sparse.bcoo_update_layout(mat, *, n_batch=None, n_dense=None, on_inefficient='error')
更新 BCOO 矩阵的存储布局(即 n_batch 和 n_dense)。
在许多情况下,可以在不引入不必要的存储开销的情况下完成此操作。然而,增加 mat.n_batch
或 mat.n_dense
将导致存储效率非常低下,许多零值都是显式存储的,除非新的批处理或密集维度的大小为 0 或 1。在这种情况下,bcoo_update_layout
将引发 SparseEfficiencyError
。可以通过指定 on_inefficient
参数来消除此警告。
参数:
- mat(BCOO) – BCOO 数组
- n_batch(int | None) – 可选参数(整数),输出矩阵中批处理维度的数量。如果为 None,则 n_batch = mat.n_batch。
- n_dense(int | None) – 可选参数(整数),输出矩阵中密集维度的数量。如果为 None,则 n_dense = mat.n_dense。
- on_inefficient(str | None) – 可选参数(字符串),其中之一
['error', 'warn', None]
。指定在重新配置效率低下的情况下的行为。这被定义为结果表示的大小远大于输入表示的情况。
返回:
BCOO 数组
表示与输入相同的稀疏数组的 BCOO 数组,具有指定的布局。 mat_out.todense()
将与 mat.todense()
在适当的精度上匹配。
返回类型:
mat_out
jax.experimental.sparse.bcoo_reduce_sum
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_reduce_sum.html
jax.experimental.sparse.bcoo_reduce_sum(mat, *, axes)
对给定轴上的数组元素求和。
参数:
返回:
包含结果的 BCOO 格式数组。
返回类型:
BCOO
jax.experimental.sparse.bcoo_reshape
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_reshape.html
jax.experimental.sparse.bcoo_reshape(mat, *, new_sizes, dimensions=None)
稀疏实现的{func}jax.lax.reshape
。
参数:
- operand – 待重塑的 BCOO 数组。
- new_sizes (Sequence[int]) – 指定结果形状的整数序列。最终数组的大小必须与输入的大小相匹配。这必须指定为批量、稀疏和密集维度不混合的形式。
- dimensions (Sequence[int] | None) – 可选的整数序列,指定输入形状的排列顺序。如果指定,长度必须与
operand.shape
相匹配。此外,维度必须仅在 mat 的相似维度之间进行排列:批量、稀疏和密集维度不能混合排列。 - mat (BCOO)
返回:
重塑后的数组。
返回类型:
输出
jax.experimental.sparse.bcoo_slice
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_slice.html
jax.experimental.sparse.bcoo_slice(mat, *, start_indices, limit_indices, strides=None)
{func}jax.lax.slice
的稀疏实现。
参数:
- mat (BCOO) – 待重新形状的 BCOO 数组。
- 起始索引 (Sequence[int]) – 长度为 mat.ndim 的整数序列,指定每个切片的起始索引。
- 限制索引 (Sequence[int]) – 长度为 mat.ndim 的整数序列,指定每个切片的结束索引
- 步幅 (Sequence[int] | None) – (未实现) 长度为 mat.ndim 的整数序列,指定每个切片的步幅
返回:
包含切片的 BCOO 数组。
返回类型:
输出
jax.experimental.sparse.bcoo_sort_indices
链接:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_sort_indices.html
jax.experimental.sparse.bcoo_sort_indices(mat)
排序一个 BCOO 数组的索引。
参数:
mat(BCOO)– BCOO 数组
返回:
带有已排序索引的 BCOO 数组。
返回类型:
mat_out
jax.experimental.sparse.bcoo_squeeze
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_squeeze.html
jax.experimental.sparse.bcoo_squeeze(arr, *, dimensions)
{func}jax.lax.squeeze
的稀疏实现。
从数组中挤出任意数量的大小为 1 的维度。
参数:
返回:
重新塑形的数组。
返回类型:
out
jax.experimental.sparse.bcoo_sum_duplicates
jax.experimental.sparse.bcoo_sum_duplicates(mat, nse=None)
对 BCOO 数组内的重复索引求和,返回一个带有排序索引的数组。
参数:
- mat (BCOO) – BCOO 数组
- nse (int | None) – 整数(可选)。输出矩阵中指定元素的数量。这必须指定以使 bcoo_sum_duplicates 兼容 JIT 和其他 JAX 变换。如果未指定,将根据数据和索引数组的内容计算最佳 nse。如果指定的 nse 大于必要的数量,将使用标准填充值填充数据和索引数组。如果小于必要的数量,将从输出矩阵中删除数据元素。
返回:
BCOO 数组具有排序索引且无重复索引。
返回类型:
mat_out
jax.experimental.sparse.bcoo_todense
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_todense.html
jax.experimental.sparse.bcoo_todense(mat)
将批处理稀疏矩阵转换为稠密矩阵。
参数:
mat(BCOO)– BCOO 矩阵。
返回:
mat
的稠密版本。
返回类型:
mat_dense
jax.experimental.sparse.bcoo_transpose
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_transpose.html
jax.experimental.sparse.bcoo_transpose(mat, *, permutation)
转置 BCOO 格式的数组。
参数:
- mat (BCOO) – 一个 BCOO 格式的数组。
- permutation (Sequence[int]) – 一个元组、列表或 ndarray,其中包含对
mat
的轴进行排列的置换,顺序为批处理、稀疏和稠密维度。返回数组的第 i 个轴对应于mat
的编号为 permutation[i] 的轴。目前,转置置换不支持将批处理轴与非批处理轴混合,也不支持将稠密轴与非稠密轴混合。
返回:
BCOO 格式的数组。
返回类型:
BCOO
jax.experimental.jet 模块
Jet 是一个实验性模块,用于更高阶的自动微分,不依赖于重复的一阶自动微分。
如何?通过截断的泰勒多项式的传播。考虑一个函数 ( f = g \circ h ),某个点 ( x ) 和某个偏移 ( v )。一阶自动微分(如 jax.jvp()
)从对 ((h(x), \partial h(x)[v])) 的计算得到对 ((f(x), \partial f(x)[v])) 的计算。
jet()
实现了更高阶的类似方法:给定元组
((h_0, … h_K) := (h(x), \partial h(x)[v], \partial² h(x)[v, v], …, \partial^K h(x)[v,…,v])),
代表在 ( x ) 处 ( h ) 的 ( K ) 阶泰勒近似,jet()
返回在 ( x ) 处 ( f ) 的 ( K ) 阶泰勒近似,
((f_0, …, f_K) := (f(x), \partial f(x)[v], \partial² f(x)[v, v], …, \partial^K f(x)[v,…,v])).
更具体地说,jet()
计算
[f_0, (f_1, . . . , f_K) = \texttt{jet} (f, h_0, (h_1, . . . , h_K))]
因此可用于 ( f ) 的高阶自动微分。详细内容请参见 这些注释。
注
通过贡献 优秀的原始规则 来改进 jet()
。
API
jax.experimental.jet.jet(fun, primals, series)
泰勒模式高阶自动微分。
参数:
- fun – 要进行微分的函数。其参数应为数组、标量或标准 Python 容器中的数组或标量。应返回一个数组、标量或标准 Python 容器中的数组或标量。
- primals – 应评估
fun
泰勒近似值的原始值。应该是参数的元组或列表,并且其长度应与fun
的位置参数数量相等。 - 系列 – 更高阶的泰勒级数系数。原始数据和系列数据组成了一个截断的泰勒多项式。应该是一个元组或列表,其长度决定了截断的泰勒多项式的阶数。
返回:
一个 (primals_out, series_out)
对,其中 primals_out
是 fun(*primals)
的值,primals_out
和 series_out
一起构成了 ( f(h(\cdot)) ) 的截断泰勒多项式。primals_out
的值具有与 primals
相同的 Python 树结构,series_out
的值具有与 series
相同的 Python 树结构。
例如:
>>> import jax >>> import jax.numpy as np
考虑函数 ( h(z) = z³ ),( x = 0.5 ),和前几个泰勒系数 ( h_0=x³ ),( h_1=3x² ),( h_2=6x )。让 ( f(y) = \sin(y) )。
>>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5 >>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args)
jet()
根据法阿·迪布鲁诺公式返回 ( f(h(z)) = \sin(z³) ) 的泰勒系数:
>>> f0, (f1, f2) = jet(f, (h0,), ((h1, h2),)) >>> print(f0, f(h0)) 0.12467473 0.12467473
>>> print(f1, df(h0) * h1) 0.7441479 0.74414825
>>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2) 2.9064622 2.9064634
jax.experimental.custom_partitioning 模块
原文:
jax.readthedocs.io/en/latest/jax.experimental.custom_partitioning.html
API
jax.experimental.custom_partitioning.custom_partitioning(fun, static_argnums=())
在 XLA 图中插入一个 CustomCallOp,并使用自定义的 SPMD 降低规则。
@custom_partitioning def f(*args): return ... def propagate_user_sharding(mesh, user_shape): '''Update the sharding of the op from a user's shape.sharding.''' user_sharding = jax.tree.map(lambda x: x.sharding, user_shape) def partition(mesh, arg_shapes, result_shape): def lower_fn(*args): ... builds computation on per-device shapes ... result_shardings = jax.tree.map(lambda x: x.sharding, result_shape) arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) # result_sharding and arg_shardings may optionally be modified and the # partitioner will insert collectives to reshape. return mesh, lower_fn, result_sharding, arg_shardings def infer_sharding_from_operands(mesh, arg_shapes, shape): '''Compute the result sharding from the sharding of the operands.''' arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) f.def_partition(partition, propagate_user_sharding, infer_sharding_from_operands)
def_partition
的参数如下:
propagate_user_sharding
:一个可调用对象,接受用户(在 DAG 中)的分片并返回一个新的 NamedSharding 的建议。默认实现只是返回建议的分片。partition
:一个可调用对象,接受 SPMD 建议的分片形状和分片规格,并返回网格、每个分片的降低函数以及最终的输入和输出分片规格(SPMD 分片器将重新分片输入以匹配)。返回网格以允许在未提供网格时配置集体的 axis_names。infer_sharding_from_operands
:一个可调用对象,从每个参数选择的NamedSharding
中计算输出的NamedSharding
。decode_shardings
:当设置为 True 时,如果可能,从输入中转换pyGSPMDSharding``s to ``NamedSharding
。如果用户未提供上下文网格,则可能无法执行此操作。
可以使用 static_argnums 将位置参数指定为静态参数。JAX 使用 inspect.signature(fun)
来解析这些位置参数。
示例
例如,假设我们想增强现有的 jax.numpy.fft.fft
。该函数计算 N 维输入沿最后一个维度的离散 Fourier 变换,并且在前 N-1 维度上进行批处理。但是,默认情况下,它会忽略输入的分片并在所有设备上收集输入。然而,由于 jax.numpy.fft.fft
在前 N-1 维度上进行批处理,这是不必要的。我们将创建一个新的 my_fft
操作,它不会改变前 N-1 维度上的分片,并且仅在需要时沿最后一个维度收集输入。
import jax from jax.sharding import NamedSharding from jax.experimental.custom_partitioning import custom_partitioning from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P from jax.sharding import Mesh from jax.numpy.fft import fft import regex as re import numpy as np # Pattern to detect all-gather or dynamic-slice in the generated HLO _PATTERN = '(dynamic-slice|all-gather)' # For an N-D input, keeps sharding along the first N-1 dimensions # but replicate along the last dimension def supported_sharding(sharding, shape): rank = len(shape.shape) max_shared_dims = min(len(sharding.spec), rank-1) names = tuple(sharding.spec[:max_shared_dims]) + tuple(None for _ in range(rank - max_shared_dims)) return NamedSharding(sharding.mesh, P(*names)) def partition(mesh, arg_shapes, result_shape): result_shardings = jax.tree.map(lambda x: x.sharding, result_shape) arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) return mesh, fft, supported_sharding(arg_shardings[0], arg_shapes[0]), (supported_sharding(arg_shardings[0], arg_shapes[0]),) def infer_sharding_from_operands(mesh, arg_shapes, result_shape): arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) return supported_sharding(arg_shardings[0], arg_shapes[0]) @custom_partitioning def my_fft(x): return fft(x) my_fft.def_partition( infer_sharding_from_operands=infer_sharding_from_operands, partition=partition)
现在创建一个沿第一个轴分片的二维数组,通过 my_fft
处理它,并注意它仍按预期进行分片,并且与 fft
的输出相同。但是,检查 HLO(使用 lower(x).compile().runtime_executable().hlo_modules()
)显示 my_fft
不创建任何全收集或动态切片,而 fft
则创建。
with Mesh(np.array(jax.devices()), ('x',)): x = np.asarray(np.random.randn(32*1024, 1024), dtype=np.complex64) y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x) pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x')) pjit_fft = pjit(fft, in_shardings=P('x'), out_shardings=P('x')) print(pjit_my_fft(y)) print(pjit_fft(y)) # dynamic-slice or all-gather are not present in the HLO for my_fft, because x is a 2D array assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None) # dynamic-slice or all-gather are present in the HLO for fft assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
# my_fft [[-38.840824 +0.j -40.649452 +11.845365j ... -1.6937828 +0.8402481j 15.999859 -4.0156755j]] # jax.numpy.fft.fft [[-38.840824 +0.j -40.649452 +11.845365j ... -1.6937828 +0.8402481j 15.999859 -4.0156755j]]
由于 supported_sharding
中的逻辑,my_fft
也适用于一维数组。但是,在这种情况下,my_fft
的 HLO 显示动态切片,因为最后一个维度是计算 FFT 的维度,在计算之前需要在所有设备上复制。
with Mesh(np.array(jax.devices()), ('x',)): x = np.asarray(np.random.randn(32*1024*1024), dtype=np.complex64) y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x) pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x')) pjit_fft = pjit(fft, in_shardings=P('x'), out_shardings=P('x')) print(pjit_my_fft(y)) print(pjit_fft(y)) # dynamic-slice or all-gather are present in the HLO for my_fft, because x is a 1D array assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None) # dynamic-slice or all-gather are present in the HLO for fft assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
# my_fft [ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j ... 1422.4502 +7271.4297j -405.84033 -3042.983j -3012.4963 -4287.6343j] # jax.numpy.fft.fft [ 7.217285 +0.j -3012.4937 +4287.635j -405.83594 +3042.984j ... 1422.4502 +7271.4297j -405.84033 -3042.983j -3012.4963 -4287.6343j]
jax.experimental.multihost_utils 模块
原文:
jax.readthedocs.io/en/latest/jax.experimental.multihost_utils.html
用于跨多个主机同步和通信的实用程序。
多主机工具 API 参考
broadcast_one_to_all (in_tree[, is_source]) |
从源主机(默认为主机 0)向所有其他主机广播数据。 |
sync_global_devices (name) |
在所有主机/设备之间创建屏障。 |
process_allgather (in_tree[, tiled]) |
从各个进程收集数据。 |
assert_equal (in_tree[, fail_message]) |
验证所有主机具有相同的值树。 |
host_local_array_to_global_array (…) |
将主机本地值转换为全局分片的 jax.Array。 |
global_array_to_host_local_array (…) |
将全局 jax.Array 转换为主机本地 jax.Array。 |
jax.experimental.compilation_cache 模块
原文:
jax.readthedocs.io/en/latest/jax.experimental.compilation_cache.html
JAX 磁盘编译缓存。
API
jax.experimental.compilation_cache.compilation_cache.is_initialized()
已废弃。
返回缓存是否已启用。初始化可以延迟,因此不会检查初始化状态。该名称保留以确保向后兼容性。
返回类型:
jax.experimental.compilation_cache.compilation_cache.initialize_cache(path)
此 API 已废弃;请使用set_cache_dir
替代。
设置路径。为了生效,在调用get_executable_and_time()
和put_executable_and_time()
之前应该调用此方法。
返回类型:
无
jax.experimental.compilation_cache.compilation_cache.set_cache_dir(path)
设置持久化编译缓存目录。
调用此方法后,jit 编译的函数将保存到路径中,因此如果进程重新启动或再次运行,则无需重新编译。这也告诉 Jax 在编译之前从哪里查找已编译的函数。
返回类型:
无
jax.experimental.compilation_cache.compilation_cache.reset_cache()
返回到原始未初始化状态。
返回类型:
无
jax.experimental.key_reuse 模块
原文:
jax.readthedocs.io/en/latest/jax.experimental.key_reuse.html
实验性密钥重用检查
此模块包含用于检测 JAX 程序中随机密钥重用的实验性功能。它正在积极开发中,并且这里的 API 可能会发生变化。下面的使用需要 JAX 版本 0.4.26 或更新版本。
可以通过 jax_debug_key_reuse
配置启用密钥重用检查。全局设置如下:
>>> jax.config.update('jax_debug_key_reuse', True)
或者可以通过 jax.debug_key_reuse()
上下文管理器在本地启用。启用后,使用相同的密钥两次将导致 KeyReuseError
:
>>> import jax >>> with jax.debug_key_reuse(True): ... key = jax.random.key(0) ... val1 = jax.random.normal(key) ... val2 = jax.random.normal(key) Traceback (most recent call last): ... KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0
目前密钥重用检查器处于实验阶段,但未来我们可能会默认启用它。
jax.experimental.mesh_utils 模块
原文:
jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html
用于构建设备网格的实用工具。
API
create_device_mesh (mesh_shape[, devices, …]) |
为 jax.sharding.Mesh 创建一个高性能的设备网格。 |
create_hybrid_device_mesh (mesh_shape, …[, …]) |
创建一个用于混合(例如 ICI 和 DCN)并行性的设备网格。 |
jax.experimental.serialize_executable 模块
原文:
jax.readthedocs.io/en/latest/jax.experimental.serialize_executable.html
为预编译二进制文件提供了 Pickling 支持。
API
serialize (compiled) |
序列化编译后的二进制文件。 |
deserialize_and_load (serialized, in_tree, …) |
从序列化的可执行文件构建一个 jax.stages.Compiled 对象。 |
jax.experimental.shard_map 模块
原文:
jax.readthedocs.io/en/latest/jax.experimental.shard_map.html
API
shard_map (f, mesh, in_specs, out_specs[, …]) |
将一个函数映射到数据的分片上。 |
jax.lib 模块
jax.lib 包是一组内部工具和类型,用于连接 JAX 的 Python 前端和其 XLA 后端。
jax.lib.xla_bridge
default_backend () |
返回默认 XLA 后端的平台名称。 |
get_backend ([platform]) |
|
get_compile_options (num_replicas, num_partitions) |
返回用于编译的选项,从标志值派生而来。 |
jax.lib.xla_client
JAX 中文文档(十六)(2)https://developer.aliyun.com/article/1559727