JAX 中文文档(十六)(1)

简介: JAX 中文文档(十六)


原文:jax.readthedocs.io/en/latest/

jax.experimental.sparse.bcoo_multiply_dense

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_multiply_dense.html

jax.experimental.sparse.bcoo_multiply_dense(sp_mat, v)

稀疏数组和稠密数组之间的逐元素乘法。

参数:

  • lhs – 一个 BCOO 格式的数组。
  • rhs – 一个 ndarray
  • sp_matBCOO
  • vArray

返回:

包含结果的 ndarray。

返回类型:

Array

jax.experimental.sparse.bcoo_multiply_sparse

jax.experimental.sparse.bcoo_multiply_sparse

jax.experimental.sparse.bcoo_multiply_sparse(lhs, rhs)

两个稀疏数组的逐元素乘积。

参数:

  • lhs (BCOO) – 一个 BCOO 格式的数组。
  • rhs (BCOO) – 一个 BCOO 格式的数组。

返回值:

包含结果的 BCOO 格式数组。

返回类型:

BCOO

jax.experimental.sparse.bcoo_update_layout

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_update_layout.html

jax.experimental.sparse.bcoo_update_layout(mat, *, n_batch=None, n_dense=None, on_inefficient='error')

更新 BCOO 矩阵的存储布局(即 n_batch 和 n_dense)。

在许多情况下,可以在不引入不必要的存储开销的情况下完成此操作。然而,增加 mat.n_batchmat.n_dense 将导致存储效率非常低下,许多零值都是显式存储的,除非新的批处理或密集维度的大小为 0 或 1。在这种情况下,bcoo_update_layout 将引发 SparseEfficiencyError。可以通过指定 on_inefficient 参数来消除此警告。

参数:

  • matBCOO) – BCOO 数组
  • n_batchint | None) – 可选参数(整数),输出矩阵中批处理维度的数量。如果为 None,则 n_batch = mat.n_batch。
  • n_denseint | None) – 可选参数(整数),输出矩阵中密集维度的数量。如果为 None,则 n_dense = mat.n_dense。
  • on_inefficientstr | None) – 可选参数(字符串),其中之一 ['error', 'warn', None]。指定在重新配置效率低下的情况下的行为。这被定义为结果表示的大小远大于输入表示的情况。

返回:

BCOO 数组

表示与输入相同的稀疏数组的 BCOO 数组,具有指定的布局。 mat_out.todense() 将与 mat.todense() 在适当的精度上匹配。

返回类型:

mat_out

jax.experimental.sparse.bcoo_reduce_sum

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_reduce_sum.html

jax.experimental.sparse.bcoo_reduce_sum(mat, *, axes)

对给定轴上的数组元素求和。

参数:

  • matBCOO) – 一个 BCOO 格式的数组。
  • shape – 目标数组的形状。
  • axesSequence[int]) – 包含mat上进行求和的轴的元组、列表或 ndarray。

返回:

包含结果的 BCOO 格式数组。

返回类型:

BCOO

jax.experimental.sparse.bcoo_reshape

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_reshape.html

jax.experimental.sparse.bcoo_reshape(mat, *, new_sizes, dimensions=None)

稀疏实现的{func}jax.lax.reshape

参数:

  • operand – 待重塑的 BCOO 数组。
  • new_sizes (Sequence[int]) – 指定结果形状的整数序列。最终数组的大小必须与输入的大小相匹配。这必须指定为批量、稀疏和密集维度不混合的形式。
  • dimensions (Sequence[int] | None) – 可选的整数序列,指定输入形状的排列顺序。如果指定,长度必须与operand.shape相匹配。此外,维度必须仅在 mat 的相似维度之间进行排列:批量、稀疏和密集维度不能混合排列。
  • mat (BCOO)

返回:

重塑后的数组。

返回类型:

输出

jax.experimental.sparse.bcoo_slice

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_slice.html

jax.experimental.sparse.bcoo_slice(mat, *, start_indices, limit_indices, strides=None)

{func}jax.lax.slice 的稀疏实现。

参数:

  • mat (BCOO) – 待重新形状的 BCOO 数组。
  • 起始索引 (Sequence[int]) – 长度为 mat.ndim 的整数序列,指定每个切片的起始索引。
  • 限制索引 (Sequence[int]) – 长度为 mat.ndim 的整数序列,指定每个切片的结束索引
  • 步幅 (Sequence[int] | None) – (未实现) 长度为 mat.ndim 的整数序列,指定每个切片的步幅

返回:

包含切片的 BCOO 数组。

返回类型:

输出

jax.experimental.sparse.bcoo_sort_indices

链接:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_sort_indices.html

jax.experimental.sparse.bcoo_sort_indices(mat)

排序一个 BCOO 数组的索引。

参数:

matBCOO)– BCOO 数组

返回:

带有已排序索引的 BCOO 数组。

返回类型:

mat_out

jax.experimental.sparse.bcoo_squeeze

jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_squeeze.html

jax.experimental.sparse.bcoo_squeeze(arr, *, dimensions)

{func}jax.lax.squeeze 的稀疏实现。

从数组中挤出任意数量的大小为 1 的维度。

参数:

  • arr (BCOO) – 要重新塑形的 BCOO 数组。
  • 维度 (Sequence[int]) – 指定要挤压的整数序列。

返回:

重新塑形的数组。

返回类型:

out

jax.experimental.sparse.bcoo_sum_duplicates

原文

jax.experimental.sparse.bcoo_sum_duplicates(mat, nse=None)

对 BCOO 数组内的重复索引求和,返回一个带有排序索引的数组。

参数:

  • mat (BCOO) – BCOO 数组
  • nse (int | None)  – 整数(可选)。输出矩阵中指定元素的数量。这必须指定以使 bcoo_sum_duplicates 兼容 JIT 和其他 JAX  变换。如果未指定,将根据数据和索引数组的内容计算最佳 nse。如果指定的 nse  大于必要的数量,将使用标准填充值填充数据和索引数组。如果小于必要的数量,将从输出矩阵中删除数据元素。

返回:

BCOO 数组具有排序索引且无重复索引。

返回类型:

mat_out

jax.experimental.sparse.bcoo_todense

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_todense.html

jax.experimental.sparse.bcoo_todense(mat)

将批处理稀疏矩阵转换为稠密矩阵。

参数:

matBCOO)– BCOO 矩阵。

返回:

mat 的稠密版本。

返回类型:

mat_dense

jax.experimental.sparse.bcoo_transpose

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_transpose.html

jax.experimental.sparse.bcoo_transpose(mat, *, permutation)

转置 BCOO 格式的数组。

参数:

  • mat (BCOO) – 一个 BCOO 格式的数组。
  • permutation (Sequence[int]) – 一个元组、列表或 ndarray,其中包含对 mat 的轴进行排列的置换,顺序为批处理、稀疏和稠密维度。返回数组的第 i 个轴对应于 mat 的编号为 permutation[i] 的轴。目前,转置置换不支持将批处理轴与非批处理轴混合,也不支持将稠密轴与非稠密轴混合。

返回:

BCOO 格式的数组。

返回类型:

BCOO

jax.experimental.jet 模块

原文:jax.readthedocs.io/en/latest/jax.experimental.jet.html

Jet 是一个实验性模块,用于更高阶的自动微分,不依赖于重复的一阶自动微分。

如何?通过截断的泰勒多项式的传播。考虑一个函数 ( f = g \circ h ),某个点 ( x ) 和某个偏移 ( v )。一阶自动微分(如 jax.jvp())从对 ((h(x), \partial h(x)[v])) 的计算得到对 ((f(x), \partial f(x)[v])) 的计算。

jet() 实现了更高阶的类似方法:给定元组

((h_0, … h_K) := (h(x), \partial h(x)[v], \partial² h(x)[v, v], …, \partial^K h(x)[v,…,v])),

代表在 ( x ) 处 ( h ) 的 ( K ) 阶泰勒近似,jet() 返回在 ( x ) 处 ( f ) 的 ( K ) 阶泰勒近似,

((f_0, …, f_K) := (f(x), \partial f(x)[v], \partial² f(x)[v, v], …, \partial^K f(x)[v,…,v])).

更具体地说,jet() 计算

[f_0, (f_1, . . . , f_K) = \texttt{jet} (f, h_0, (h_1, . . . , h_K))]

因此可用于 ( f ) 的高阶自动微分。详细内容请参见 这些注释

通过贡献 优秀的原始规则 来改进 jet()

API

jax.experimental.jet.jet(fun, primals, series)

泰勒模式高阶自动微分。

参数:

  • fun – 要进行微分的函数。其参数应为数组、标量或标准 Python 容器中的数组或标量。应返回一个数组、标量或标准 Python 容器中的数组或标量。
  • primals – 应评估 fun 泰勒近似值的原始值。应该是参数的元组或列表,并且其长度应与 fun 的位置参数数量相等。
  • 系列 – 更高阶的泰勒级数系数。原始数据和系列数据组成了一个截断的泰勒多项式。应该是一个元组或列表,其长度决定了截断的泰勒多项式的阶数。

返回:

一个 (primals_out, series_out) 对,其中 primals_outfun(*primals) 的值,primals_outseries_out 一起构成了 ( f(h(\cdot)) ) 的截断泰勒多项式。primals_out 的值具有与 primals 相同的 Python 树结构,series_out 的值具有与 series 相同的 Python 树结构。

例如:

>>> import jax
>>> import jax.numpy as np 

考虑函数 ( h(z) = z³ ),( x = 0.5 ),和前几个泰勒系数 ( h_0=x³ ),( h_1=3x² ),( h_2=6x )。让 ( f(y) = \sin(y) )。

>>> h0, h1, h2 = 0.5**3., 3.*0.5**2., 6.*0.5
>>> f, df, ddf = np.sin, np.cos, lambda *args: -np.sin(*args) 

jet() 根据法阿·迪布鲁诺公式返回 ( f(h(z)) = \sin(z³) ) 的泰勒系数:

>>> f0, (f1, f2) =  jet(f, (h0,), ((h1, h2),))
>>> print(f0,  f(h0))
0.12467473 0.12467473 
>>> print(f1, df(h0) * h1)
0.7441479 0.74414825 
>>> print(f2, ddf(h0) * h1 ** 2 + df(h0) * h2)
2.9064622 2.9064634 

jax.experimental.custom_partitioning 模块

原文:jax.readthedocs.io/en/latest/jax.experimental.custom_partitioning.html

API

jax.experimental.custom_partitioning.custom_partitioning(fun, static_argnums=())

在 XLA 图中插入一个 CustomCallOp,并使用自定义的 SPMD 降低规则。

@custom_partitioning
def f(*args):
  return ...
def propagate_user_sharding(mesh, user_shape):
  '''Update the sharding of the op from a user's shape.sharding.'''
  user_sharding = jax.tree.map(lambda x: x.sharding, user_shape)
def partition(mesh, arg_shapes, result_shape):
  def lower_fn(*args):
    ... builds computation on per-device shapes ...
  result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
  arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
  # result_sharding and arg_shardings may optionally be modified and the
  # partitioner will insert collectives to reshape.
  return mesh, lower_fn, result_sharding, arg_shardings
def infer_sharding_from_operands(mesh, arg_shapes, shape):
  '''Compute the result sharding from the sharding of the operands.'''
  arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
f.def_partition(partition, propagate_user_sharding, infer_sharding_from_operands) 

def_partition 的参数如下:

  • propagate_user_sharding:一个可调用对象,接受用户(在 DAG 中)的分片并返回一个新的 NamedSharding 的建议。默认实现只是返回建议的分片。
  • partition:一个可调用对象,接受 SPMD 建议的分片形状和分片规格,并返回网格、每个分片的降低函数以及最终的输入和输出分片规格(SPMD 分片器将重新分片输入以匹配)。返回网格以允许在未提供网格时配置集体的 axis_names。
  • infer_sharding_from_operands:一个可调用对象,从每个参数选择的 NamedSharding 中计算输出的 NamedSharding
  • decode_shardings:当设置为 True 时,如果可能,从输入中转换 pyGSPMDSharding``s to ``NamedSharding。如果用户未提供上下文网格,则可能无法执行此操作。

可以使用 static_argnums 将位置参数指定为静态参数。JAX 使用 inspect.signature(fun) 来解析这些位置参数。

示例

例如,假设我们想增强现有的 jax.numpy.fft.fft。该函数计算 N 维输入沿最后一个维度的离散 Fourier 变换,并且在前 N-1 维度上进行批处理。但是,默认情况下,它会忽略输入的分片并在所有设备上收集输入。然而,由于 jax.numpy.fft.fft 在前 N-1 维度上进行批处理,这是不必要的。我们将创建一个新的 my_fft 操作,它不会改变前 N-1 维度上的分片,并且仅在需要时沿最后一个维度收集输入。

import jax
from jax.sharding import NamedSharding
from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
from jax.numpy.fft import fft
import regex as re
import numpy as np
# Pattern to detect all-gather or dynamic-slice in the generated HLO
_PATTERN = '(dynamic-slice|all-gather)'
# For an N-D input, keeps sharding along the first N-1 dimensions
# but replicate along the last dimension
def supported_sharding(sharding, shape):
    rank = len(shape.shape)
    max_shared_dims = min(len(sharding.spec), rank-1)
    names = tuple(sharding.spec[:max_shared_dims]) + tuple(None for _ in range(rank - max_shared_dims))
    return NamedSharding(sharding.mesh, P(*names))
def partition(mesh, arg_shapes, result_shape):
    result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
    return mesh, fft,               supported_sharding(arg_shardings[0], arg_shapes[0]),               (supported_sharding(arg_shardings[0], arg_shapes[0]),)
def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
    return supported_sharding(arg_shardings[0], arg_shapes[0])
@custom_partitioning
def my_fft(x):
    return fft(x)
my_fft.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition) 

现在创建一个沿第一个轴分片的二维数组,通过 my_fft 处理它,并注意它仍按预期进行分片,并且与 fft 的输出相同。但是,检查 HLO(使用 lower(x).compile().runtime_executable().hlo_modules())显示 my_fft 不创建任何全收集或动态切片,而 fft 则创建。

with Mesh(np.array(jax.devices()), ('x',)):
  x = np.asarray(np.random.randn(32*1024, 1024), dtype=np.complex64)
  y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x)
  pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x'))
  pjit_fft    = pjit(fft,    in_shardings=P('x'), out_shardings=P('x'))
  print(pjit_my_fft(y))
  print(pjit_fft(y))
  # dynamic-slice or all-gather are not present in the HLO for my_fft, because x is a 2D array
  assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None)
  # dynamic-slice or all-gather are present in the HLO for fft
  assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string())    is not None) 
# my_fft
[[-38.840824   +0.j        -40.649452  +11.845365j
...
  -1.6937828  +0.8402481j  15.999859   -4.0156755j]]
# jax.numpy.fft.fft
[[-38.840824   +0.j        -40.649452  +11.845365j
  ...
  -1.6937828  +0.8402481j  15.999859   -4.0156755j]] 

由于 supported_sharding 中的逻辑,my_fft 也适用于一维数组。但是,在这种情况下,my_fft 的 HLO 显示动态切片,因为最后一个维度是计算 FFT 的维度,在计算之前需要在所有设备上复制。

with Mesh(np.array(jax.devices()), ('x',)):
  x = np.asarray(np.random.randn(32*1024*1024), dtype=np.complex64)
  y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x)
  pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x'))
  pjit_fft    = pjit(fft,    in_shardings=P('x'), out_shardings=P('x'))
  print(pjit_my_fft(y))
  print(pjit_fft(y))
  # dynamic-slice or all-gather are present in the HLO for my_fft, because x is a 1D array
  assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None)
  # dynamic-slice or all-gather are present in the HLO for fft
  assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string())    is not None) 
# my_fft
[    7.217285   +0.j     -3012.4937  +4287.635j   -405.83594 +3042.984j
...  1422.4502  +7271.4297j  -405.84033 -3042.983j
-3012.4963  -4287.6343j]
# jax.numpy.fft.fft
[    7.217285   +0.j     -3012.4937  +4287.635j   -405.83594 +3042.984j
...  1422.4502  +7271.4297j  -405.84033 -3042.983j
-3012.4963  -4287.6343j] 

jax.experimental.multihost_utils 模块

原文:jax.readthedocs.io/en/latest/jax.experimental.multihost_utils.html

用于跨多个主机同步和通信的实用程序。

多主机工具 API 参考

broadcast_one_to_all(in_tree[, is_source]) 从源主机(默认为主机 0)向所有其他主机广播数据。
sync_global_devices(name) 在所有主机/设备之间创建屏障。
process_allgather(in_tree[, tiled]) 从各个进程收集数据。
assert_equal(in_tree[, fail_message]) 验证所有主机具有相同的值树。
host_local_array_to_global_array(…) 将主机本地值转换为全局分片的 jax.Array。
global_array_to_host_local_array(…) 将全局 jax.Array 转换为主机本地 jax.Array。

jax.experimental.compilation_cache 模块

原文:jax.readthedocs.io/en/latest/jax.experimental.compilation_cache.html

JAX 磁盘编译缓存。

API

jax.experimental.compilation_cache.compilation_cache.is_initialized()

已废弃。

返回缓存是否已启用。初始化可以延迟,因此不会检查初始化状态。该名称保留以确保向后兼容性。

返回类型:

bool

jax.experimental.compilation_cache.compilation_cache.initialize_cache(path)

此 API 已废弃;请使用set_cache_dir替代。

设置路径。为了生效,在调用get_executable_and_time()put_executable_and_time()之前应该调用此方法。

返回类型:

jax.experimental.compilation_cache.compilation_cache.set_cache_dir(path)

设置持久化编译缓存目录。

调用此方法后,jit 编译的函数将保存到路径中,因此如果进程重新启动或再次运行,则无需重新编译。这也告诉 Jax 在编译之前从哪里查找已编译的函数。

返回类型:

jax.experimental.compilation_cache.compilation_cache.reset_cache()

返回到原始未初始化状态。

返回类型:

jax.experimental.key_reuse 模块

原文:jax.readthedocs.io/en/latest/jax.experimental.key_reuse.html

实验性密钥重用检查

此模块包含用于检测 JAX 程序中随机密钥重用的实验性功能。它正在积极开发中,并且这里的 API 可能会发生变化。下面的使用需要 JAX 版本 0.4.26 或更新版本。

可以通过 jax_debug_key_reuse 配置启用密钥重用检查。全局设置如下:

>>> jax.config.update('jax_debug_key_reuse', True) 

或者可以通过 jax.debug_key_reuse() 上下文管理器在本地启用。启用后,使用相同的密钥两次将导致 KeyReuseError

>>> import jax
>>> with jax.debug_key_reuse(True):
...   key = jax.random.key(0)
...   val1 = jax.random.normal(key)
...   val2 = jax.random.normal(key)  
Traceback (most recent call last):
  ...
KeyReuseError: Previously-consumed key passed to jit-compiled function at index 0 

目前密钥重用检查器处于实验阶段,但未来我们可能会默认启用它。

jax.experimental.mesh_utils 模块

原文:jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html

用于构建设备网格的实用工具。

API

create_device_mesh(mesh_shape[, devices, …]) 为 jax.sharding.Mesh 创建一个高性能的设备网格。
create_hybrid_device_mesh(mesh_shape, …[, …]) 创建一个用于混合(例如 ICI 和 DCN)并行性的设备网格。

jax.experimental.serialize_executable 模块

原文:jax.readthedocs.io/en/latest/jax.experimental.serialize_executable.html

为预编译二进制文件提供了 Pickling 支持。

API

serialize(compiled) 序列化编译后的二进制文件。
deserialize_and_load(serialized, in_tree, …) 从序列化的可执行文件构建一个 jax.stages.Compiled 对象。

jax.experimental.shard_map 模块

原文:jax.readthedocs.io/en/latest/jax.experimental.shard_map.html

API

shard_map(f, mesh, in_specs, out_specs[, …]) 将一个函数映射到数据的分片上。

jax.lib 模块

原文:jax.readthedocs.io/en/latest/jax.lib.html

jax.lib 包是一组内部工具和类型,用于连接 JAX 的 Python 前端和其 XLA 后端。

jax.lib.xla_bridge

default_backend() 返回默认 XLA 后端的平台名称。
get_backend([platform])
get_compile_options(num_replicas, num_partitions) 返回用于编译的选项,从标志值派生而来。

jax.lib.xla_client


JAX 中文文档(十六)(2)https://developer.aliyun.com/article/1559727

相关文章
|
3月前
|
并行计算 异构计算 索引
JAX 中文文档(十六)(4)
JAX 中文文档(十六)
28 2
|
3月前
|
并行计算 算法框架/工具 异构计算
JAX 中文文档(十六)(5)
JAX 中文文档(十六)
44 2
|
3月前
|
存储 API 索引
JAX 中文文档(十五)(5)
JAX 中文文档(十五)
35 3
|
3月前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(4)
JAX 中文文档(十五)
26 3
|
3月前
|
并行计算 API 异构计算
JAX 中文文档(十六)(2)
JAX 中文文档(十六)
80 1
|
3月前
|
机器学习/深度学习 数据可视化 编译器
JAX 中文文档(十四)(5)
JAX 中文文档(十四)
34 2
|
3月前
|
算法 API 开发工具
JAX 中文文档(十二)(5)
JAX 中文文档(十二)
41 1
|
3月前
|
并行计算 API 异构计算
JAX 中文文档(十六)(3)
JAX 中文文档(十六)
56 0
|
3月前
|
Python
JAX 中文文档(十四)(4)
JAX 中文文档(十四)
25 0
|
3月前
|
关系型数据库
JAX 中文文档(十四)(1)
JAX 中文文档(十四)
22 0