JAX 中文文档(十五)(1)https://developer.aliyun.com/article/1559766
jax.extend.ffi 模块
ffi_lowering (call_target_name, *[, …]) |
构建一个外部函数接口(FFI)目标的降低规则。 |
pycapsule (funcptr) |
将一个 ctypes 函数指针包装在 PyCapsule 中。 |
jax.extend.linear_util 模块
StoreException |
|
WrappedFun (f, transforms, stores, params, …) |
表示要应用转换的函数 f。 |
cache (call, *[, explain]) |
用于将 WrappedFun 作为第一个参数的函数的记忆化装饰器。 |
merge_linear_aux (aux1, aux2) |
|
transformation |
向 WrappedFun 添加一个转换。 |
transformation_with_aux |
向 WrappedFun 添加一个带有辅助输出的转换。 |
wrap_init (f[, params]) |
将函数 f 包装为 WrappedFun,适用于转换。 |
jax.extend.mlir 模块
jax.extend.random 模块
define_prng_impl (*, key_shape, seed, split, …) |
|
seed_with_impl (impl, seed) |
|
threefry2x32_p |
|
threefry_2x32 (keypair, count) |
应用 Threefry 2x32 哈希函数。 |
threefry_prng_impl |
指定 PRNG 密钥形状和操作。 |
rbg_prng_impl |
指定 PRNG 密钥形状和操作。 |
unsafe_rbg_prng_impl |
指定 PRNG 密钥形状和操作。 |
jax.example_libraries 模块
JAX 提供了一些小型的实验性机器学习库。这些库一部分提供工具,另一部分作为使用 JAX 构建此类库的示例。每个库的源代码行数不超过 300 行,因此请查看并根据需要进行调整!
注意
每个小型库的目的是灵感,而非规范。
为了达到这个目的,最好保持它们的代码示例简洁;因此,我们通常不会合并添加新功能的 PR。相反,请将您可爱的拉取请求和设计想法发送到更完整的库,如Haiku或Flax。
jax.example_libraries.optimizers
模块jax.example_libraries.stax
模块
jax.example_libraries.optimizers 模块
原文:
jax.readthedocs.io/en/latest/jax.example_libraries.optimizers.html
JAX 中如何编写优化器的示例。
您可能不想导入此模块!此库中的优化器仅供示例使用。如果您正在寻找功能完善的优化器库,两个不错的选择是 JAXopt 和 Optax。
此模块包含一些方便的优化器定义,特别是初始化和更新函数,可用于 ndarray 或任意嵌套的 tuple/list/dict 的 ndarray。
优化器被建模为一个 (init_fun, update_fun, get_params)
函数三元组,其中组件函数具有以下签名:
init_fun(params) Args: params: pytree representing the initial parameters. Returns: A pytree representing the initial optimizer state, which includes the initial parameters and may also include auxiliary values like initial momentum. The optimizer state pytree structure generally differs from that of `params`.
update_fun(step, grads, opt_state) Args: step: integer representing the step index. grads: a pytree with the same structure as `get_params(opt_state)` representing the gradients to be used in updating the optimizer state. opt_state: a pytree representing the optimizer state to be updated. Returns: A pytree with the same structure as the `opt_state` argument representing the updated optimizer state.
get_params(opt_state) Args: opt_state: pytree representing an optimizer state. Returns: A pytree representing the parameters extracted from `opt_state`, such that the invariant `params == get_params(init_fun(params))` holds true.
注意,优化器实现在 opt_state 的形式上具有很大的灵活性:它只需是 JaxTypes 的 pytree(以便可以将其传递给 api.py 中定义的 JAX 变换),并且它必须可以被 update_fun 和 get_params 消耗。
示例用法:
opt_init, opt_update, get_params = optimizers.sgd(learning_rate) opt_state = opt_init(params) def step(step, opt_state): value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state)) opt_state = opt_update(step, grads, opt_state) return value, opt_state for i in range(num_steps): value, opt_state = step(i, opt_state)
class jax.example_libraries.optimizers.JoinPoint(subtree)
Bases: object
标记了两个连接(嵌套)的 pytree 之间的边界。
class jax.example_libraries.optimizers.Optimizer(init_fn, update_fn, params_fn)
Bases: NamedTuple
参数:
- init_fn (Callable[**[Any]**, OptimizerState*]*)
- update_fn (Callable[**[int, Any, OptimizerState*]*,* OptimizerState]*)
- params_fn (Callable[[OptimizerState]*,* Any])
init_fn: Callable[[Any], OptimizerState]
字段 0 的别名
params_fn: Callable[[OptimizerState], Any]
字段 2 的别名
update_fn: Callable[[int, Any, OptimizerState], OptimizerState]
字段 1 的别名
class jax.example_libraries.optimizers.OptimizerState(packed_state, tree_def, subtree_defs)
Bases: tuple
packed_state
字段 0 的别名
subtree_defs
字段 2 的别名
tree_def
字段 1 的别名
jax.example_libraries.optimizers.adagrad(step_size, momentum=0.9)
构建 Adagrad 的优化器三元组。
适应性次梯度方法用于在线学习和随机优化:www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf
参数:
- step_size – 正标量,或者将迭代索引映射到正标量的可调用对象的步长表达式。
- momentum – 可选,用于动量的正标量值
返回:
一个 (init_fun, update_fun, get_params) 三元组。
jax.example_libraries.optimizers.adam(step_size, b1=0.9, b2=0.999, eps=1e-08)
构建 Adam 的优化器三元组。
参数:
- step_size – 正的标量,或者一个可调用对象,表示将迭代索引映射到正的标量的步长计划。
- b1 – 可选,一个正的标量值,用于 beta_1,第一个时刻估计的指数衰减率(默认为 0.9)。
- b2 – 可选,一个正的标量值,用于 beta_2,第二个时刻估计的指数衰减率(默认为 0.999)。
- eps – 可选,一个正的标量值,用于 epsilon,即数值稳定性的小常数(默认为 1e-8)。
返回:
一个 (init_fun, update_fun, get_params) 三元组。
jax.example_libraries.optimizers.adamax(step_size, b1=0.9, b2=0.999, eps=1e-08)
为 AdaMax(基于无穷范数的 Adam 变体)构造优化器三元组。
参数:
- step_size – 正的标量,或者一个可调用对象,表示将迭代索引映射到正的标量的步长计划。
- b1 – 可选,一个正的标量值,用于 beta_1,第一个时刻估计的指数衰减率(默认为 0.9)。
- b2 – 可选,一个正的标量值,用于 beta_2,第二个时刻估计的指数衰减率(默认为 0.999)。
- eps – 可选,一个正的标量值,用于 epsilon,即数值稳定性的小常数(默认为 1e-8)。
返回:
一个 (init_fun, update_fun, get_params) 三元组。
jax.example_libraries.optimizers.clip_grads(grad_tree, max_norm)
将存储为 pytree 结构的梯度裁剪到最大范数 max_norm。
jax.example_libraries.optimizers.constant(step_size)
返回类型:
jax.example_libraries.optimizers.exponential_decay(step_size, decay_steps, decay_rate)
jax.example_libraries.optimizers.inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False)
jax.example_libraries.optimizers.l2_norm(tree)
计算一个 pytree 结构的数组的 l2 范数。适用于权重衰减。
jax.example_libraries.optimizers.make_schedule(scalar_or_schedule)
参数:
scalar_or_schedule (float | Callable[**[int]**, float])
返回类型:
jax.example_libraries.optimizers.momentum(step_size, mass)
为带动量的 SGD 构造优化器三元组。
参数:
- step_size (Callable[**[int]**, float]) – 正的标量,或者一个可调用对象,表示将迭代索引映射到正的标量的步长计划。
- mass (float) – 正的标量,表示动量系数。
返回:
一个 (init_fun, update_fun, get_params) 三元组。
jax.example_libraries.optimizers.nesterov(step_size, mass)
为带有 Nesterov 动量的 SGD 构建优化器三元组。
参数:
返回:
一个(init_fun, update_fun, get_params)三元组。
jax.example_libraries.optimizers.optimizer(opt_maker)
装饰器,使定义为数组的优化器通用于容器。
使用此装饰器,您可以编写只对单个数组操作的 init、update 和 get_params 函数,并将它们转换为对参数 pytrees 进行操作的相应函数。有关示例,请参见 optimizers.py 中定义的优化器。
参数:
opt_maker(Callable[[…], tuple[Callable[**[Any]**, Any]**, Callable[**[int, Any, Any]**, Any]**, Callable[**[Any]**, Any]]]) –
返回一个返回(init_fun, update_fun, get_params)
函数三元组的函数,该函数可能仅适用于 ndarrays,如
init_fun :: ndarray -> OptStatePytree ndarray update_fun :: OptStatePytree ndarray -> OptStatePytree ndarray get_params :: OptStatePytree ndarray -> ndarray
返回:
一个(init_fun, update_fun, get_params)
函数三元组,这些函数按照任意 pytrees 进行操作,如
init_fun :: ParameterPytree ndarray -> OptimizerState update_fun :: OptimizerState -> OptimizerState get_params :: OptimizerState -> ParameterPytree ndarray
返回函数使用的 OptimizerState pytree 类型与ParameterPytree (OptStatePytree ndarray)
相同,但可能出于性能考虑将状态存储为部分展平的数据结构。
返回类型:
Callable[[…], Optimizer]
jax.example_libraries.optimizers.pack_optimizer_state(marked_pytree)
将标记的 pytree 转换为 OptimizerState。
unpack_optimizer_state 的逆操作。将一个带有 JoinPoints 的标记 pytree(其外部 pytree 的叶子表示为 JoinPoints)转换回一个 OptimizerState。这个函数用于在反序列化优化器状态时很有用。
参数:
marked_pytree – 一个包含 JoinPoint 叶子的 pytree,其保持更多 pytree。
返回:
输入参数的等效 OptimizerState。
jax.example_libraries.optimizers.piecewise_constant(boundaries, values)
参数:
jax.example_libraries.optimizers.polynomial_decay(step_size, decay_steps, final_step_size, power=1.0)
jax.example_libraries.optimizers.rmsprop(step_size, gamma=0.9, eps=1e-08)
为 RMSProp 构造优化器三元组。
参数:
step_size – 正标量,或者一个可调用函数,表示将迭代索引映射到正标量的步长计划。gamma:衰减参数。eps:Epsilon 参数。
返回:
一个(init_fun, update_fun, get_params)三元组。
jax.example_libraries.optimizers.rmsprop_momentum(step_size, gamma=0.9, eps=1e-08, momentum=0.9)
为带动量的 RMSProp 构造优化器三元组。
这个优化器与 rmsprop 优化器分开,因为它需要跟踪额外的参数。
参数:
- step_size – 正标量,或者一个可调用函数,表示将迭代索引映射到正标量的步长计划。
- gamma – 衰减参数。
- eps – Epsilon 参数。
- momentum – 动量参数。
返回:
一个(init_fun, update_fun, get_params)三元组。
jax.example_libraries.optimizers.sgd(step_size)
为随机梯度下降构造优化器三元组。
参数:
step_size – 正标量,或者一个可调用函数,表示将迭代索引映射到正标量的步长计划。
返回:
一个(init_fun, update_fun, get_params)三元组。
jax.example_libraries.optimizers.sm3(step_size, momentum=0.9)
为 SM3 构造优化器三元组。
大规模学习的内存高效自适应优化。arxiv.org/abs/1901.11150
参数:
- step_size – 正标量,或者一个可调用函数,表示将迭代索引映射到正标量的步长计划。
- momentum – 可选,动量的正标量值
返回:
一个(init_fun, update_fun, get_params)三元组。
jax.example_libraries.optimizers.unpack_optimizer_state(opt_state)
将一个 OptimizerState 转换为带有 JoinPoints 叶子的标记 pytree。
将一个 OptimizerState 转换为带有 JoinPoints 叶子的标记 pytree,以避免丢失信息。这个函数在序列化优化器状态时很有用。
参数:
opt_state – 一个 OptimizerState
返回:
一个带有 JoinPoint 叶子的 pytree,其包含第二级 pytree。
jax.example_libraries.stax 模块
原文:
jax.readthedocs.io/en/latest/jax.example_libraries.stax.html
Stax 是一个从头开始的小而灵活的神经网络规范库。
您可能不想导入此模块!Stax 仅用作示例库。对于 JAX,还有许多其他功能更全面的神经网络库,包括来自 Google 的Flax 和来自 DeepMind 的Haiku。
jax.example_libraries.stax.AvgPool(window_shape, strides=None, padding='VALID', spec=None)
用于创建池化层的层构造函数。
jax.example_libraries.stax.BatchNorm(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, beta_init=<function zeros>, gamma_init=<function ones>)
用于创建批量归一化层的层构造函数。
jax.example_libraries.stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)
用于创建通用卷积层的层构造函数。
jax.example_libraries.stax.Conv1DTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)
用于创建通用转置卷积层的层构造函数。
jax.example_libraries.stax.ConvTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)
用于创建通用转置卷积层的层构造函数。
jax.example_libraries.stax.Dense(out_dim, W_init=<function variance_scaling.<locals>.init>, b_init=<function normal.<locals>.init>)
用于创建密集(全连接)层的层构造函数。
jax.example_libraries.stax.Dropout(rate, mode='train')
用于给定率创建丢弃层的层构造函数。
jax.example_libraries.stax.FanInConcat(axis=-1)
用于创建扇入连接层的层构造函数。
jax.example_libraries.stax.FanOut(num)
用于创建扇出层的层构造函数。
jax.example_libraries.stax.GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)
用于创建通用卷积层的层构造函数。
jax.example_libraries.stax.GeneralConvTranspose(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)
用于创建通用转置卷积层的层构造函数。
jax.example_libraries.stax.MaxPool(window_shape, strides=None, padding='VALID', spec=None)
用于创建池化层的层构造函数。
jax.example_libraries.stax.SumPool(window_shape, strides=None, padding='VALID', spec=None)
用于创建池化层的层构造函数。
jax.example_libraries.stax.elementwise(fun, **fun_kwargs)
在其输入上逐元素应用标量函数的层。
jax.example_libraries.stax.parallel(*layers)
并行组合层的组合器。
此组合器生成的层通常与 FanOut 和 FanInSum 层一起使用。
参数:
*layers – 一个层序列,每个都是(init_fun, apply_fun)对。
返回:
表示给定层序列的并行组合的新层,即(init_fun, apply_fun)对。特别地,返回的层接受一个输入序列,并返回一个与参数层长度相同的输出序列。
jax.example_libraries.stax.serial(*layers)
串行组合层的组合器。
参数:
*layers – 一个层序列,每个都是(init_fun, apply_fun)对。
返回:
表示给定层序列的串行组合的新层,即(init_fun, apply_fun)对。
jax.example_libraries.stax.shape_dependent(make_layer)
延迟层构造对直到输入形状已知的组合器。
参数:
make_layer – 一个以输入形状(正整数元组)为参数的单参数函数,返回一个(init_fun, apply_fun)对。
返回:
表示与 make_layer 返回的相同层的新层,但其构造被延迟直到输入形状已知。
jax.experimental 模块
jax.experimental.optix
已迁移到其自己的 Python 包中 (deepmind/optax)。
jax.experimental.ann
已迁移到 jax.lax
。
实验性模块
jax.experimental.array_api
模块jax.experimental.checkify
模块jax.experimental.host_callback
模块jax.experimental.maps
模块jax.experimental.pjit
模块jax.experimental.sparse
模块jax.experimental.jet
模块jax.experimental.custom_partitioning
模块jax.experimental.multihost_utils
模块jax.experimental.compilation_cache
模块jax.experimental.key_reuse
模块jax.experimental.mesh_utils
模块jax.experimental.serialize_executable
模块jax.experimental.shard_map
模块
实验性 API
enable_x64 ([new_val]) |
实验性上下文管理器,临时启用 X64 模式。 |
disable_x64 () |
实验性上下文管理器,临时禁用 X64 模式。 |
jax.experimental.checkify.checkify (f[, errors]) |
在函数 f 中功能化检查调用,并可选地添加运行时错误检查。 |
jax.experimental.checkify.check (pred, msg, …) |
检查谓词,如果谓词为假,则添加带有消息的错误。 |
jax.experimental.checkify.check_error (error) |
如果 error 表示失败,则引发异常。 |
jax.experimental.array_api 模块
原文:
jax.readthedocs.io/en/latest/jax.experimental.array_api.html
此模块包括对 Python 数组 API 标准 的实验性 JAX 支持。目前对此的支持是实验性的,且尚未完全完成。
示例用法:
>>> from jax.experimental import array_api as xp >>> xp.__array_api_version__ '2023.12' >>> arr = xp.arange(1000) >>> arr.sum() Array(499500, dtype=int32)
xp
命名空间是 jax.numpy
的数组 API 兼容版本,并实现了大部分标准中列出的 API。
jax.experimental.checkify 模块
原文:
jax.readthedocs.io/en/latest/jax.experimental.checkify.html
API
checkify (f[, errors]) |
将检查调用功能化在函数 f 中,并可选择添加运行时错误检查。 |
check (pred, msg, *fmt_args, **fmt_kwargs) |
检查一个断言,如果断言为 False,则添加带有消息 msg 的错误。 |
check_error (error) |
如果 error 表示失败,则抛出异常。 |
Error (_pred, _code, _metadata, _payload) |
|
JaxRuntimeError |
|
user_checks |
frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象 |
nan_checks |
frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象 |
index_checks |
frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象 |
div_checks |
frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象 |
float_checks |
frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象 |
automatic_checks |
frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象 |
all_checks |
frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象 |
JAX 中文文档(十五)(3)https://developer.aliyun.com/article/1559769