JAX 中文文档(十五)(2)

简介: JAX 中文文档(十五)

JAX 中文文档(十五)(1)https://developer.aliyun.com/article/1559766

jax.extend.ffi 模块

原文:jax.readthedocs.io/en/latest/jax.extend.ffi.html

ffi_lowering(call_target_name, *[, …]) 构建一个外部函数接口(FFI)目标的降低规则。
pycapsule(funcptr) 将一个 ctypes 函数指针包装在 PyCapsule 中。

jax.extend.linear_util 模块

原文:jax.readthedocs.io/en/latest/jax.extend.linear_util.html

StoreException
WrappedFun(f, transforms, stores, params, …) 表示要应用转换的函数 f。
cache(call, *[, explain]) 用于将 WrappedFun 作为第一个参数的函数的记忆化装饰器。
merge_linear_aux(aux1, aux2)
transformation 向 WrappedFun 添加一个转换。
transformation_with_aux 向 WrappedFun 添加一个带有辅助输出的转换。
wrap_init(f[, params]) 将函数 f 包装为 WrappedFun,适用于转换。

jax.extend.mlir 模块

原文:jax.readthedocs.io/en/latest/jax.extend.mlir.html

jax.extend.random 模块

原文:jax.readthedocs.io/en/latest/jax.extend.random.html

define_prng_impl(*, key_shape, seed, split, …)
seed_with_impl(impl, seed)
threefry2x32_p
threefry_2x32(keypair, count) 应用 Threefry 2x32 哈希函数。
threefry_prng_impl 指定 PRNG 密钥形状和操作。
rbg_prng_impl 指定 PRNG 密钥形状和操作。
unsafe_rbg_prng_impl 指定 PRNG 密钥形状和操作。

jax.example_libraries 模块

原文:jax.readthedocs.io/en/latest/jax.example_libraries.html

JAX 提供了一些小型的实验性机器学习库。这些库一部分提供工具,另一部分作为使用 JAX 构建此类库的示例。每个库的源代码行数不超过 300 行,因此请查看并根据需要进行调整!

注意

每个小型库的目的是灵感,而非规范。

为了达到这个目的,最好保持它们的代码示例简洁;因此,我们通常不会合并添加新功能的 PR。相反,请将您可爱的拉取请求和设计想法发送到更完整的库,如HaikuFlax

  • jax.example_libraries.optimizers 模块
  • jax.example_libraries.stax 模块

jax.example_libraries.optimizers 模块

原文:jax.readthedocs.io/en/latest/jax.example_libraries.optimizers.html

JAX 中如何编写优化器的示例。

您可能不想导入此模块!此库中的优化器仅供示例使用。如果您正在寻找功能完善的优化器库,两个不错的选择是 JAXoptOptax

此模块包含一些方便的优化器定义,特别是初始化和更新函数,可用于 ndarray 或任意嵌套的 tuple/list/dict 的 ndarray。

优化器被建模为一个 (init_fun, update_fun, get_params) 函数三元组,其中组件函数具有以下签名:

init_fun(params)
Args:
  params: pytree representing the initial parameters.
Returns:
  A pytree representing the initial optimizer state, which includes the
  initial parameters and may also include auxiliary values like initial
  momentum. The optimizer state pytree structure generally differs from that
  of `params`. 
update_fun(step, grads, opt_state)
Args:
  step: integer representing the step index.
  grads: a pytree with the same structure as `get_params(opt_state)`
    representing the gradients to be used in updating the optimizer state.
  opt_state: a pytree representing the optimizer state to be updated.
Returns:
  A pytree with the same structure as the `opt_state` argument representing
  the updated optimizer state. 
get_params(opt_state)
Args:
  opt_state: pytree representing an optimizer state.
Returns:
  A pytree representing the parameters extracted from `opt_state`, such that
  the invariant `params == get_params(init_fun(params))` holds true. 

注意,优化器实现在 opt_state 的形式上具有很大的灵活性:它只需是 JaxTypes 的 pytree(以便可以将其传递给 api.py 中定义的 JAX 变换),并且它必须可以被 update_fun 和 get_params 消耗。

示例用法:

opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_state = opt_init(params)
def step(step, opt_state):
  value, grads = jax.value_and_grad(loss_fn)(get_params(opt_state))
  opt_state = opt_update(step, grads, opt_state)
  return value, opt_state
for i in range(num_steps):
  value, opt_state = step(i, opt_state) 
class jax.example_libraries.optimizers.JoinPoint(subtree)

Bases: object

标记了两个连接(嵌套)的 pytree 之间的边界。

class jax.example_libraries.optimizers.Optimizer(init_fn, update_fn, params_fn)

Bases: NamedTuple

参数:

init_fn: Callable[[Any], OptimizerState]

字段 0 的别名

params_fn: Callable[[OptimizerState], Any]

字段 2 的别名

update_fn: Callable[[int, Any, OptimizerState], OptimizerState]

字段 1 的别名

class jax.example_libraries.optimizers.OptimizerState(packed_state, tree_def, subtree_defs)

Bases: tuple

packed_state

字段 0 的别名

subtree_defs

字段 2 的别名

tree_def

字段 1 的别名

jax.example_libraries.optimizers.adagrad(step_size, momentum=0.9)

构建 Adagrad 的优化器三元组。

适应性次梯度方法用于在线学习和随机优化:www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf

参数:

  • step_size – 正标量,或者将迭代索引映射到正标量的可调用对象的步长表达式。
  • momentum – 可选,用于动量的正标量值

返回:

一个 (init_fun, update_fun, get_params) 三元组。

jax.example_libraries.optimizers.adam(step_size, b1=0.9, b2=0.999, eps=1e-08)

构建 Adam 的优化器三元组。

参数:

  • step_size – 正的标量,或者一个可调用对象,表示将迭代索引映射到正的标量的步长计划。
  • b1 – 可选,一个正的标量值,用于 beta_1,第一个时刻估计的指数衰减率(默认为 0.9)。
  • b2 – 可选,一个正的标量值,用于 beta_2,第二个时刻估计的指数衰减率(默认为 0.999)。
  • eps – 可选,一个正的标量值,用于 epsilon,即数值稳定性的小常数(默认为 1e-8)。

返回:

一个 (init_fun, update_fun, get_params) 三元组。

jax.example_libraries.optimizers.adamax(step_size, b1=0.9, b2=0.999, eps=1e-08)

为 AdaMax(基于无穷范数的 Adam 变体)构造优化器三元组。

参数:

  • step_size – 正的标量,或者一个可调用对象,表示将迭代索引映射到正的标量的步长计划。
  • b1 – 可选,一个正的标量值,用于 beta_1,第一个时刻估计的指数衰减率(默认为 0.9)。
  • b2 – 可选,一个正的标量值,用于 beta_2,第二个时刻估计的指数衰减率(默认为 0.999)。
  • eps – 可选,一个正的标量值,用于 epsilon,即数值稳定性的小常数(默认为 1e-8)。

返回:

一个 (init_fun, update_fun, get_params) 三元组。

jax.example_libraries.optimizers.clip_grads(grad_tree, max_norm)

将存储为 pytree 结构的梯度裁剪到最大范数 max_norm。

jax.example_libraries.optimizers.constant(step_size)

返回类型:

Callable[[int], float]

jax.example_libraries.optimizers.exponential_decay(step_size, decay_steps, decay_rate)
jax.example_libraries.optimizers.inverse_time_decay(step_size, decay_steps, decay_rate, staircase=False)
jax.example_libraries.optimizers.l2_norm(tree)

计算一个 pytree 结构的数组的 l2 范数。适用于权重衰减。

jax.example_libraries.optimizers.make_schedule(scalar_or_schedule)

参数:

scalar_or_schedule (float | Callable[**[int]**, float])

返回类型:

Callable[[int], float]

jax.example_libraries.optimizers.momentum(step_size, mass)

为带动量的 SGD 构造优化器三元组。

参数:

  • step_size (Callable[**[int]**, float]) – 正的标量,或者一个可调用对象,表示将迭代索引映射到正的标量的步长计划。
  • mass (float) – 正的标量,表示动量系数。

返回:

一个 (init_fun, update_fun, get_params) 三元组。

jax.example_libraries.optimizers.nesterov(step_size, mass)

为带有 Nesterov 动量的 SGD 构建优化器三元组。

参数:

  • step_sizeCallable[**[int]**, float]) – 正标量,或表示将迭代索引映射到正标量的步长计划的可调用对象。
  • massfloat) – 正标量,表示动量系数。

返回:

一个(init_fun, update_fun, get_params)三元组。

jax.example_libraries.optimizers.optimizer(opt_maker)

装饰器,使定义为数组的优化器通用于容器。

使用此装饰器,您可以编写只对单个数组操作的 init、update 和 get_params 函数,并将它们转换为对参数 pytrees 进行操作的相应函数。有关示例,请参见 optimizers.py 中定义的优化器。

参数:

opt_makerCallable[[], tuple[Callable[**[Any]**, Any]**, Callable[**[int, Any, Any]**, Any]**, Callable[**[Any]**, Any]]]) –

返回一个返回(init_fun, update_fun, get_params)函数三元组的函数,该函数可能仅适用于 ndarrays,如

init_fun  ::  ndarray  ->  OptStatePytree  ndarray
update_fun  ::  OptStatePytree  ndarray  ->  OptStatePytree  ndarray
get_params  ::  OptStatePytree  ndarray  ->  ndarray 

返回:

一个(init_fun, update_fun, get_params)函数三元组,这些函数按照任意 pytrees 进行操作,如

init_fun  ::  ParameterPytree  ndarray  ->  OptimizerState
update_fun  ::  OptimizerState  ->  OptimizerState
get_params  ::  OptimizerState  ->  ParameterPytree  ndarray 

返回函数使用的 OptimizerState pytree 类型与ParameterPytree (OptStatePytree ndarray)相同,但可能出于性能考虑将状态存储为部分展平的数据结构。

返回类型:

Callable[[…], Optimizer]

jax.example_libraries.optimizers.pack_optimizer_state(marked_pytree)

将标记的 pytree 转换为 OptimizerState。

unpack_optimizer_state 的逆操作。将一个带有 JoinPoints 的标记 pytree(其外部 pytree 的叶子表示为 JoinPoints)转换回一个 OptimizerState。这个函数用于在反序列化优化器状态时很有用。

参数:

marked_pytree – 一个包含 JoinPoint 叶子的 pytree,其保持更多 pytree。

返回:

输入参数的等效 OptimizerState。

jax.example_libraries.optimizers.piecewise_constant(boundaries, values)

参数:

jax.example_libraries.optimizers.polynomial_decay(step_size, decay_steps, final_step_size, power=1.0)
jax.example_libraries.optimizers.rmsprop(step_size, gamma=0.9, eps=1e-08)

为 RMSProp 构造优化器三元组。

参数:

step_size – 正标量,或者一个可调用函数,表示将迭代索引映射到正标量的步长计划。gamma:衰减参数。eps:Epsilon 参数。

返回:

一个(init_fun, update_fun, get_params)三元组。

jax.example_libraries.optimizers.rmsprop_momentum(step_size, gamma=0.9, eps=1e-08, momentum=0.9)

为带动量的 RMSProp 构造优化器三元组。

这个优化器与 rmsprop 优化器分开,因为它需要跟踪额外的参数。

参数:

  • step_size – 正标量,或者一个可调用函数,表示将迭代索引映射到正标量的步长计划。
  • gamma – 衰减参数。
  • eps – Epsilon 参数。
  • momentum – 动量参数。

返回:

一个(init_fun, update_fun, get_params)三元组。

jax.example_libraries.optimizers.sgd(step_size)

为随机梯度下降构造优化器三元组。

参数:

step_size – 正标量,或者一个可调用函数,表示将迭代索引映射到正标量的步长计划。

返回:

一个(init_fun, update_fun, get_params)三元组。

jax.example_libraries.optimizers.sm3(step_size, momentum=0.9)

为 SM3 构造优化器三元组。

大规模学习的内存高效自适应优化。arxiv.org/abs/1901.11150

参数:

  • step_size – 正标量,或者一个可调用函数,表示将迭代索引映射到正标量的步长计划。
  • momentum – 可选,动量的正标量值

返回:

一个(init_fun, update_fun, get_params)三元组。

jax.example_libraries.optimizers.unpack_optimizer_state(opt_state)

将一个 OptimizerState 转换为带有 JoinPoints 叶子的标记 pytree。

将一个 OptimizerState 转换为带有 JoinPoints 叶子的标记 pytree,以避免丢失信息。这个函数在序列化优化器状态时很有用。

参数:

opt_state – 一个 OptimizerState

返回:

一个带有 JoinPoint 叶子的 pytree,其包含第二级 pytree。

jax.example_libraries.stax 模块

原文:jax.readthedocs.io/en/latest/jax.example_libraries.stax.html

Stax 是一个从头开始的小而灵活的神经网络规范库。

您可能不想导入此模块!Stax 仅用作示例库。对于 JAX,还有许多其他功能更全面的神经网络库,包括来自 Google 的Flax 和来自 DeepMind 的Haiku

jax.example_libraries.stax.AvgPool(window_shape, strides=None, padding='VALID', spec=None)

用于创建池化层的层构造函数。

jax.example_libraries.stax.BatchNorm(axis=(0, 1, 2), epsilon=1e-05, center=True, scale=True, beta_init=<function zeros>, gamma_init=<function ones>)

用于创建批量归一化层的层构造函数。

jax.example_libraries.stax.Conv(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)

用于创建通用卷积层的层构造函数。

jax.example_libraries.stax.Conv1DTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)

用于创建通用转置卷积层的层构造函数。

jax.example_libraries.stax.ConvTranspose(out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)

用于创建通用转置卷积层的层构造函数。

jax.example_libraries.stax.Dense(out_dim, W_init=<function variance_scaling.<locals>.init>, b_init=<function normal.<locals>.init>)

用于创建密集(全连接)层的层构造函数。

jax.example_libraries.stax.Dropout(rate, mode='train')

用于给定率创建丢弃层的层构造函数。

jax.example_libraries.stax.FanInConcat(axis=-1)

用于创建扇入连接层的层构造函数。

jax.example_libraries.stax.FanOut(num)

用于创建扇出层的层构造函数。

jax.example_libraries.stax.GeneralConv(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)

用于创建通用卷积层的层构造函数。

jax.example_libraries.stax.GeneralConvTranspose(dimension_numbers, out_chan, filter_shape, strides=None, padding='VALID', W_init=None, b_init=<function normal.<locals>.init>)

用于创建通用转置卷积层的层构造函数。

jax.example_libraries.stax.MaxPool(window_shape, strides=None, padding='VALID', spec=None)

用于创建池化层的层构造函数。

jax.example_libraries.stax.SumPool(window_shape, strides=None, padding='VALID', spec=None)

用于创建池化层的层构造函数。

jax.example_libraries.stax.elementwise(fun, **fun_kwargs)

在其输入上逐元素应用标量函数的层。

jax.example_libraries.stax.parallel(*layers)

并行组合层的组合器。

此组合器生成的层通常与 FanOut 和 FanInSum 层一起使用。

参数:

*layers – 一个层序列,每个都是(init_fun, apply_fun)对。

返回:

表示给定层序列的并行组合的新层,即(init_fun, apply_fun)对。特别地,返回的层接受一个输入序列,并返回一个与参数层长度相同的输出序列。

jax.example_libraries.stax.serial(*layers)

串行组合层的组合器。

参数:

*layers – 一个层序列,每个都是(init_fun, apply_fun)对。

返回:

表示给定层序列的串行组合的新层,即(init_fun, apply_fun)对。

jax.example_libraries.stax.shape_dependent(make_layer)

延迟层构造对直到输入形状已知的组合器。

参数:

make_layer – 一个以输入形状(正整数元组)为参数的单参数函数,返回一个(init_fun, apply_fun)对。

返回:

表示与 make_layer 返回的相同层的新层,但其构造被延迟直到输入形状已知。

jax.experimental 模块

原文:jax.readthedocs.io/en/latest/jax.experimental.html

jax.experimental.optix 已迁移到其自己的 Python 包中 (deepmind/optax)。

jax.experimental.ann 已迁移到 jax.lax

实验性模块

  • jax.experimental.array_api 模块
  • jax.experimental.checkify 模块
  • jax.experimental.host_callback 模块
  • jax.experimental.maps 模块
  • jax.experimental.pjit 模块
  • jax.experimental.sparse 模块
  • jax.experimental.jet 模块
  • jax.experimental.custom_partitioning 模块
  • jax.experimental.multihost_utils 模块
  • jax.experimental.compilation_cache 模块
  • jax.experimental.key_reuse 模块
  • jax.experimental.mesh_utils 模块
  • jax.experimental.serialize_executable 模块
  • jax.experimental.shard_map 模块

实验性 API

enable_x64([new_val]) 实验性上下文管理器,临时启用 X64 模式。
disable_x64() 实验性上下文管理器,临时禁用 X64 模式。
jax.experimental.checkify.checkify(f[, errors]) 在函数 f 中功能化检查调用,并可选地添加运行时错误检查。
jax.experimental.checkify.check(pred, msg, …) 检查谓词,如果谓词为假,则添加带有消息的错误。
jax.experimental.checkify.check_error(error) 如果 error 表示失败,则引发异常。

jax.experimental.array_api 模块

原文:jax.readthedocs.io/en/latest/jax.experimental.array_api.html

此模块包括对 Python 数组 API 标准 的实验性 JAX 支持。目前对此的支持是实验性的,且尚未完全完成。

示例用法:

>>> from jax.experimental import array_api as xp
>>> xp.__array_api_version__
'2023.12'
>>> arr = xp.arange(1000)
>>> arr.sum()
Array(499500, dtype=int32) 

xp 命名空间是 jax.numpy 的数组 API 兼容版本,并实现了大部分标准中列出的 API。

jax.experimental.checkify 模块

原文:jax.readthedocs.io/en/latest/jax.experimental.checkify.html

API

checkify(f[, errors]) 将检查调用功能化在函数 f 中,并可选择添加运行时错误检查。
check(pred, msg, *fmt_args, **fmt_kwargs) 检查一个断言,如果断言为 False,则添加带有消息 msg 的错误。
check_error(error) 如果 error 表示失败,则抛出异常。
Error(_pred, _code, _metadata, _payload)
JaxRuntimeError
user_checks frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象
nan_checks frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象
index_checks frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象
div_checks frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象
float_checks frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象
automatic_checks frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象
all_checks frozenset() -> 空的 frozenset 对象 frozenset(iterable) -> frozenset 对象


JAX 中文文档(十五)(3)https://developer.aliyun.com/article/1559769

相关文章
|
3月前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(4)
JAX 中文文档(十五)
26 3
|
3月前
|
存储 API 索引
JAX 中文文档(十五)(5)
JAX 中文文档(十五)
34 3
|
3月前
|
机器学习/深度学习 数据可视化 编译器
JAX 中文文档(十四)(5)
JAX 中文文档(十四)
34 2
|
3月前
|
算法 API 开发工具
JAX 中文文档(十二)(5)
JAX 中文文档(十二)
40 1
|
3月前
JAX 中文文档(十一)(5)
JAX 中文文档(十一)
14 1
|
3月前
|
API 异构计算 Python
JAX 中文文档(十一)(4)
JAX 中文文档(十一)
23 1
|
3月前
|
并行计算 API 异构计算
JAX 中文文档(十六)(2)
JAX 中文文档(十六)
77 1
|
3月前
|
安全 API 网络架构
JAX 中文文档(十五)(1)
JAX 中文文档(十五)
37 0
|
3月前
|
TensorFlow API 算法框架/工具
JAX 中文文档(十五)(3)
JAX 中文文档(十五)
22 0
|
3月前
|
Python
JAX 中文文档(十四)(4)
JAX 中文文档(十四)
25 0