JAX 中文文档(十五)(1)

简介: JAX 中文文档(十五)


原文:jax.readthedocs.io/en/latest/

jax.tree 模块

原文:jax.readthedocs.io/en/latest/jax.tree.html

用于处理树形容器数据结构的实用工具。

jax.tree 命名空间包含了来自 jax.tree_util 的实用工具的别名。

功能列表

all(tree, *[, is_leaf]) 对树的所有叶子进行 all()操作。
flatten(tree[, is_leaf]) 将一个 pytree 扁平化。
leaves(tree[, is_leaf]) 获取一个 pytree 的叶子。
map(f, tree, *rest[, is_leaf]) 将一个多输入函数映射到 pytree 参数上,生成一个新的 pytree。
reduce() 对树的叶子进行 reduce()操作。
structure(tree[, is_leaf]) 获取一个 pytree 的 treedef。
transpose(outer_treedef, inner_treedef, …) 将具有树结构 (outer, inner) 的树转换为具有结构 (inner, outer) 的树。
unflatten(treedef, leaves) 根据 treedef 和叶子重构一个 pytree。

jax.tree_util 模块

原文:jax.readthedocs.io/en/latest/jax.tree_util.html

用于处理树状容器数据结构的实用工具。

该模块提供了一小组用于处理树状数据结构(例如嵌套元组、列表和字典)的实用函数。我们称这些结构为  pytrees。它们是树形的,因为它们是递归定义的(任何非 pytree 都是 pytree,即叶子,任何 pytree 的 pytrees  都是 pytree),并且可以递归地操作(映射操作不保留对象身份等价性,并且这些结构不能包含引用循环)。

被视为 pytree 节点的 Python 类型集合(例如可以映射而不是视为叶子的类型)是可扩展的。存在一个单一的模块级别的类型注册表,并且类层次结构被忽略。通过注册一个新的 pytree 节点类型,该类型实际上变得对此文件中的实用函数透明。

该模块的主要目的是支持用户定义的数据结构与 JAX 转换(例如 jit)之间的互操作性。这不是一个通用的树状数据结构处理库。

查看 JAX pytrees 注释以获取示例。

函数列表

Partial(func, *args, **kw) 在 pytrees 中工作的 functools.partial 的版本。
all_leaves(iterable[, is_leaf]) 测试给定可迭代对象中的所有元素是否都是叶子。
build_tree(treedef, xs) 从嵌套的可迭代结构构建一个 treedef。
register_dataclass(nodetype, data_fields, …) 扩展了在 pytrees 中被视为内部节点的类型集合。
register_pytree_node(nodetype, flatten_func, …) 扩展了在 pytrees 中被视为内部节点的类型集合。
register_pytree_node_class(cls) 扩展了在 pytrees 中被视为内部节点的类型集合。
register_pytree_with_keys(nodetype, …[, …]) 扩展了在 pytrees 中被视为内部节点的类型集合。
register_pytree_with_keys_class(cls) 扩展了在 pytrees 中被视为内部节点的类型集合。
register_static(cls) 将 cls 注册为没有叶子的 pytree。
tree_flatten_with_path(tree[, is_leaf]) tree_flatten一样展平 pytree,但还返回每个叶子的键路径。
tree_leaves_with_path(tree[, is_leaf]) 获取类似tree_leaves的 pytree 的叶子,并返回每个叶子的键路径。
tree_map_with_path(f, tree, *rest[, is_leaf]) 对 pytree 键路径和参数执行多输入函数映射,生成新的 pytree。
treedef_children(treedef) 返回直接子节点的 treedef 列表。
treedef_is_leaf(treedef) 如果 treedef 表示叶子,则返回 True。
treedef_tuple(treedefs) 从子 treedefs 的可迭代对象制作一个元组 treedef。
keystr(keys) 辅助函数,用于漂亮地打印键的元组。

传统 API

现在通过jax.tree访问这些 API。

tree_all(tree, *[, is_leaf]) jax.tree.all()的别名。
tree_flatten(tree[, is_leaf]) jax.tree.flatten()的别名。
tree_leaves(tree[, is_leaf]) jax.tree.leaves()的别名。
tree_map(f, tree, *rest[, is_leaf]) jax.tree.map()的别名。
tree_reduce(function, tree[, initializer, …]) jax.tree.reduce()的别名。
tree_structure(tree[, is_leaf]) jax.tree.structure()的别名。
tree_transpose(outer_treedef, inner_treedef, …) jax.tree.transpose()的别名。
tree_unflatten(treedef, leaves) jax.tree.unflatten()的别名。

jax.typing 模块

原文:jax.readthedocs.io/en/latest/jax.typing.html

JAX 类型注解模块是 JAX 特定静态类型注解的存放地。这个子模块仍在开发中;要查看这里导出的类型背后的提案,请参阅jax.readthedocs.io/en/latest/jep/12049-type-annotations.html

当前可用的类型包括:

  • jax.Array: 适用于任何 JAX 数组或跟踪器的注解(即 JAX 变换中的数组表示)。
  • jax.typing.ArrayLike: 适用于任何安全隐式转换为 JAX 数组的值;这包括 jax.Arraynumpy.ndarray,以及 Python 内置数值类型(例如intfloat 等)和 numpy 标量值(例如 numpy.int32numpy.float64 等)。
  • jax.typing.DTypeLike:  适用于可以转换为 JAX 兼容 dtype 的任何值;这包括字符串(例如 ‘float32’、‘int32’)、标量类型(例如  float、np.float32)、dtype(例如 np.dtype(‘float32’))、或具有 dtype 属性的对象(例如  jnp.float32、jnp.int32)。

我们可能在将来的版本中添加其他类型。

JAX 类型注解最佳实践

在公共 API 函数中注释 JAX 数组时,我们建议使用 ArrayLike 来标注数组输入,使用 Array 来标注数组输出。

例如,您的函数可能如下所示:

import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike
def my_function(x: ArrayLike) -> Array:
  # Runtime type validation, Python 3.10 or newer:
  if not isinstance(x, ArrayLike):
    raise TypeError(f"Expected arraylike input; got {x}")
  # Runtime type validation, any Python version:
  if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)):
    raise TypeError(f"Expected arraylike input; got {x}")
  # Convert input to jax.Array:
  x_arr = jnp.asarray(x)
  # ... do some computation; JAX functions will return Array types:
  result = x_arr.sum(0) / x_arr.shape[0]
  # return an Array
  return result 

JAX 的大多数公共 API 遵循这种模式。特别需要注意的是,我们建议 JAX 函数不要接受序列,如listtuple,而应该接受数组,因为这样可以避免在像 jit() 这样的 JAX 变换中产生额外的开销,并且在类似批处理变换 vmap()jax.pmap() 中可能会表现出意外行为。更多信息,请参阅NumPy vs JAX 中的非数组输入

成员列表

ArrayLike 适用于 JAX 数组类似对象的类型注解。
DTypeLike 别名为str | type[Any] | dtype | SupportsDType

jax.export 模块

原文:jax.readthedocs.io/en/latest/jax.export.html

Exported(fun_name, in_tree, in_avals, …) 降低为 StableHLO 的 JAX 函数。
DisabledSafetyCheck(_impl) 应在(反)序列化时跳过的安全检查。

函数

export(fun_jit, *[, platforms, …]) 导出一个用于持久化序列化的 JAX 函数。
deserialize(blob) 反序列化一个已导出的对象。
minimum_supported_calling_convention_version int([x]) -> integer int(x, base=10) -> integer
maximum_supported_calling_convention_version int([x]) -> integer int(x, base=10) -> integer
default_export_platform() 获取默认的导出平台。

与形状多态性相关的函数

symbolic_shape(shape_spec, *[, constraints, …]) 从字符串表示中构建一个符号形状。
symbolic_args_specs(args, shapes_specs[, …]) 为导出构建一个 jax.ShapeDtypeSpec 参数规范的 pytree。
is_symbolic_dim§ 检查一个维度是否是符号维度。
SymbolicScope([constraints_str]) 标识用于符号表达式的作用域。

常量

jax.export.minimum_supported_serialization_version

最小支持的序列化版本;参见调用约定版本。

jax.export.maximum_supported_serialization_version

最大支持的序列化版本;参见调用约定版本。

jax.extend 模块

原文:jax.readthedocs.io/en/latest/jax.extend.html

JAX 扩展模块。

jax.extend 包提供了访问 JAX 内部机制的模块。参见 JEP #15856

API 政策

公共 API 不同,这个包在发布版本之间 不提供兼容性保证。突破性变更将通过 JAX 项目变更日志 进行公告。

模块

  • jax.extend.ffi 模块
  • jax.extend.linear_util 模块
  • jax.extend.mlir 模块
  • jax.extend.random 模块


JAX 中文文档(十五)(2)https://developer.aliyun.com/article/1559768

相关文章
|
4月前
|
存储 API 索引
JAX 中文文档(十五)(5)
JAX 中文文档(十五)
53 3
|
4月前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(4)
JAX 中文文档(十五)
40 3
|
4月前
|
机器学习/深度学习 数据可视化 编译器
JAX 中文文档(十四)(5)
JAX 中文文档(十四)
51 2
|
4月前
JAX 中文文档(十一)(5)
JAX 中文文档(十一)
17 1
|
4月前
|
API 异构计算 Python
JAX 中文文档(十一)(4)
JAX 中文文档(十一)
33 1
|
4月前
|
算法 API 开发工具
JAX 中文文档(十二)(5)
JAX 中文文档(十二)
55 1
|
4月前
|
并行计算 API 异构计算
JAX 中文文档(十六)(2)
JAX 中文文档(十六)
117 1
|
4月前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(2)
JAX 中文文档(十五)
36 0
|
4月前
|
TensorFlow API 算法框架/工具
JAX 中文文档(十五)(3)
JAX 中文文档(十五)
29 0
|
4月前
|
资源调度 算法 安全
JAX 中文文档(十四)(3)
JAX 中文文档(十四)
45 0