jax.tree 模块
用于处理树形容器数据结构的实用工具。
jax.tree
命名空间包含了来自 jax.tree_util
的实用工具的别名。
功能列表
all (tree, *[, is_leaf]) |
对树的所有叶子进行 all()操作。 |
flatten (tree[, is_leaf]) |
将一个 pytree 扁平化。 |
leaves (tree[, is_leaf]) |
获取一个 pytree 的叶子。 |
map (f, tree, *rest[, is_leaf]) |
将一个多输入函数映射到 pytree 参数上,生成一个新的 pytree。 |
reduce () |
对树的叶子进行 reduce()操作。 |
structure (tree[, is_leaf]) |
获取一个 pytree 的 treedef。 |
transpose (outer_treedef, inner_treedef, …) |
将具有树结构 (outer, inner) 的树转换为具有结构 (inner, outer) 的树。 |
unflatten (treedef, leaves) |
根据 treedef 和叶子重构一个 pytree。 |
jax.tree_util 模块
用于处理树状容器数据结构的实用工具。
该模块提供了一小组用于处理树状数据结构(例如嵌套元组、列表和字典)的实用函数。我们称这些结构为 pytrees。它们是树形的,因为它们是递归定义的(任何非 pytree 都是 pytree,即叶子,任何 pytree 的 pytrees 都是 pytree),并且可以递归地操作(映射操作不保留对象身份等价性,并且这些结构不能包含引用循环)。
被视为 pytree 节点的 Python 类型集合(例如可以映射而不是视为叶子的类型)是可扩展的。存在一个单一的模块级别的类型注册表,并且类层次结构被忽略。通过注册一个新的 pytree 节点类型,该类型实际上变得对此文件中的实用函数透明。
该模块的主要目的是支持用户定义的数据结构与 JAX 转换(例如 jit)之间的互操作性。这不是一个通用的树状数据结构处理库。
查看 JAX pytrees 注释以获取示例。
函数列表
Partial (func, *args, **kw) |
在 pytrees 中工作的 functools.partial 的版本。 |
all_leaves (iterable[, is_leaf]) |
测试给定可迭代对象中的所有元素是否都是叶子。 |
build_tree (treedef, xs) |
从嵌套的可迭代结构构建一个 treedef。 |
register_dataclass (nodetype, data_fields, …) |
扩展了在 pytrees 中被视为内部节点的类型集合。 |
register_pytree_node (nodetype, flatten_func, …) |
扩展了在 pytrees 中被视为内部节点的类型集合。 |
register_pytree_node_class (cls) |
扩展了在 pytrees 中被视为内部节点的类型集合。 |
register_pytree_with_keys (nodetype, …[, …]) |
扩展了在 pytrees 中被视为内部节点的类型集合。 |
register_pytree_with_keys_class (cls) |
扩展了在 pytrees 中被视为内部节点的类型集合。 |
register_static (cls) |
将 cls 注册为没有叶子的 pytree。 |
tree_flatten_with_path (tree[, is_leaf]) |
像tree_flatten 一样展平 pytree,但还返回每个叶子的键路径。 |
tree_leaves_with_path (tree[, is_leaf]) |
获取类似tree_leaves 的 pytree 的叶子,并返回每个叶子的键路径。 |
tree_map_with_path (f, tree, *rest[, is_leaf]) |
对 pytree 键路径和参数执行多输入函数映射,生成新的 pytree。 |
treedef_children (treedef) |
返回直接子节点的 treedef 列表。 |
treedef_is_leaf (treedef) |
如果 treedef 表示叶子,则返回 True。 |
treedef_tuple (treedefs) |
从子 treedefs 的可迭代对象制作一个元组 treedef。 |
keystr (keys) |
辅助函数,用于漂亮地打印键的元组。 |
传统 API
现在通过jax.tree
访问这些 API。
tree_all (tree, *[, is_leaf]) |
jax.tree.all() 的别名。 |
tree_flatten (tree[, is_leaf]) |
jax.tree.flatten() 的别名。 |
tree_leaves (tree[, is_leaf]) |
jax.tree.leaves() 的别名。 |
tree_map (f, tree, *rest[, is_leaf]) |
jax.tree.map() 的别名。 |
tree_reduce (function, tree[, initializer, …]) |
jax.tree.reduce() 的别名。 |
tree_structure (tree[, is_leaf]) |
jax.tree.structure() 的别名。 |
tree_transpose (outer_treedef, inner_treedef, …) |
jax.tree.transpose() 的别名。 |
tree_unflatten (treedef, leaves) |
jax.tree.unflatten() 的别名。 |
jax.typing 模块
JAX 类型注解模块是 JAX 特定静态类型注解的存放地。这个子模块仍在开发中;要查看这里导出的类型背后的提案,请参阅jax.readthedocs.io/en/latest/jep/12049-type-annotations.html
。
当前可用的类型包括:
jax.Array
: 适用于任何 JAX 数组或跟踪器的注解(即 JAX 变换中的数组表示)。jax.typing.ArrayLike
: 适用于任何安全隐式转换为 JAX 数组的值;这包括jax.Array
、numpy.ndarray
,以及 Python 内置数值类型(例如int
、float
等)和 numpy 标量值(例如numpy.int32
、numpy.float64
等)。jax.typing.DTypeLike
: 适用于可以转换为 JAX 兼容 dtype 的任何值;这包括字符串(例如 ‘float32’、‘int32’)、标量类型(例如 float、np.float32)、dtype(例如 np.dtype(‘float32’))、或具有 dtype 属性的对象(例如 jnp.float32、jnp.int32)。
我们可能在将来的版本中添加其他类型。
JAX 类型注解最佳实践
在公共 API 函数中注释 JAX 数组时,我们建议使用 ArrayLike
来标注数组输入,使用 Array
来标注数组输出。
例如,您的函数可能如下所示:
import numpy as np import jax.numpy as jnp from jax import Array from jax.typing import ArrayLike def my_function(x: ArrayLike) -> Array: # Runtime type validation, Python 3.10 or newer: if not isinstance(x, ArrayLike): raise TypeError(f"Expected arraylike input; got {x}") # Runtime type validation, any Python version: if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)): raise TypeError(f"Expected arraylike input; got {x}") # Convert input to jax.Array: x_arr = jnp.asarray(x) # ... do some computation; JAX functions will return Array types: result = x_arr.sum(0) / x_arr.shape[0] # return an Array return result
JAX 的大多数公共 API 遵循这种模式。特别需要注意的是,我们建议 JAX 函数不要接受序列,如list
或tuple
,而应该接受数组,因为这样可以避免在像 jit()
这样的 JAX 变换中产生额外的开销,并且在类似批处理变换 vmap()
或 jax.pmap()
中可能会表现出意外行为。更多信息,请参阅NumPy vs JAX 中的非数组输入。
成员列表
jax.export 模块
类
Exported (fun_name, in_tree, in_avals, …) |
降低为 StableHLO 的 JAX 函数。 |
DisabledSafetyCheck (_impl) |
应在(反)序列化时跳过的安全检查。 |
函数
export (fun_jit, *[, platforms, …]) |
导出一个用于持久化序列化的 JAX 函数。 |
deserialize (blob) |
反序列化一个已导出的对象。 |
minimum_supported_calling_convention_version |
int([x]) -> integer int(x, base=10) -> integer |
maximum_supported_calling_convention_version |
int([x]) -> integer int(x, base=10) -> integer |
default_export_platform () |
获取默认的导出平台。 |
与形状多态性相关的函数
symbolic_shape (shape_spec, *[, constraints, …]) |
从字符串表示中构建一个符号形状。 |
symbolic_args_specs (args, shapes_specs[, …]) |
为导出构建一个 jax.ShapeDtypeSpec 参数规范的 pytree。 |
is_symbolic_dim § |
检查一个维度是否是符号维度。 |
SymbolicScope ([constraints_str]) |
标识用于符号表达式的作用域。 |
常量
jax.export.minimum_supported_serialization_version
最小支持的序列化版本;参见调用约定版本。
jax.export.maximum_supported_serialization_version
最大支持的序列化版本;参见调用约定版本。
jax.extend 模块
JAX 扩展模块。
jax.extend
包提供了访问 JAX 内部机制的模块。参见 JEP #15856。
API 政策
与 公共 API 不同,这个包在发布版本之间 不提供兼容性保证。突破性变更将通过 JAX 项目变更日志 进行公告。
模块
jax.extend.ffi
模块jax.extend.linear_util
模块jax.extend.mlir
模块jax.extend.random
模块
JAX 中文文档(十五)(2)https://developer.aliyun.com/article/1559768