JAX 中文文档(十四)(2)

简介: JAX 中文文档(十四)

JAX 中文文档(十四)(1)https://developer.aliyun.com/article/1559755

jax.lax 模块

原文:jax.readthedocs.io/en/latest/jax.lax.html

jax.lax 是支持诸如 jax.numpy 等库的基本操作的库。通常会定义转换规则,例如 JVP 和批处理规则,作为对 jax.lax 基元的转换。

许多基元都是等价于 XLA 操作的薄包装,详细描述请参阅XLA 操作语义文档。

在可能的情况下,优先使用诸如 jax.numpy 等库,而不是直接使用 jax.laxjax.numpy API 遵循 NumPy,因此比 jax.lax API 更稳定,更不易更改。

Operators

abs(x) 按元素绝对值:(|x|)。
acos(x) 按元素求反余弦:(\mathrm{acos}(x))。
acosh(x) 按元素求反双曲余弦:(\mathrm{acosh}(x))。
add(x, y) 按元素加法:(x + y)。
after_all(*operands) 合并一个或多个 XLA 令牌值。
approx_max_k(operand, k[, …]) 以近似方式返回 operand 的最大 k 值及其索引。
approx_min_k(operand, k[, …]) 以近似方式返回 operand 的最小 k 值及其索引。
argmax(operand, axis, index_dtype) 计算沿着 axis 的最大元素的索引。
argmin(operand, axis, index_dtype) 计算沿着 axis 的最小元素的索引。
asin(x) 按元素求反正弦:(\mathrm{asin}(x))。
asinh(x) 按元素求反双曲正弦:(\mathrm{asinh}(x))。
atan(x) 按元素求反正切:(\mathrm{atan}(x))。
atan2(x, y) 两个变量的按元素反正切:(\mathrm{atan}({x \over y}))。
atanh(x) 按元素求反双曲正切:(\mathrm{atanh}(x))。
batch_matmul(lhs, rhs[, precision]) 批量矩阵乘法。
bessel_i0e(x) 指数缩放修正贝塞尔函数 (0) 阶:(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x))
bessel_i1e(x) 指数缩放修正贝塞尔函数 (1) 阶:(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x))
betainc(a, b, x) 按元素的正则化不完全贝塔积分。
bitcast_convert_type(operand, new_dtype) 按元素位转换。
bitwise_and(x, y) 按位与运算:(x \wedge y)。
bitwise_not(x) 按位取反:(\neg x)。
bitwise_or(x, y) 按位或运算:(x \vee y)。
bitwise_xor(x, y) 按位异或运算:(x \oplus y)。
population_count(x) 按元素计算 popcount,即每个元素中设置的位数。
broadcast(operand, sizes) 广播数组,添加新的前导维度。
broadcast_in_dim(operand, shape, …) 包装 XLA 的 BroadcastInDim 操作符。
broadcast_shapes() 返回经过 NumPy 广播后的形状。
broadcast_to_rank(x, rank) 添加 1 的前导维度,使 x 的等级为 rank
broadcasted_iota(dtype, shape, dimension) iota的便捷封装器。
cbrt(x) 元素级立方根:(\sqrt[3]{x})。
ceil(x) 元素级向上取整:(\left\lceil x \right\rceil)。
clamp(min, x, max) 元素级 clamp 函数。
clz(x) 元素级计算前导零的个数。
collapse(operand, start_dimension[, …]) 将数组的维度折叠为单个维度。
complex(x, y) 元素级构造复数:(x + jy)。
concatenate(operands, dimension) 沿指定维度连接一系列数组。
conj(x) 元素级复数的共轭函数:(\overline{x})。
conv(lhs, rhs, window_strides, padding[, …]) conv_general_dilated的便捷封装器。
convert_element_type(operand, new_dtype) 元素级类型转换。
conv_dimension_numbers(lhs_shape, rhs_shape, …) 将卷积维度编号转换为 ConvDimensionNumbers
conv_general_dilated(lhs, rhs, …[, …]) 带有可选扩展的通用 n 维卷积运算符。
conv_general_dilated_local(lhs, rhs, …[, …]) 带有可选扩展的通用 n 维非共享卷积运算符。
conv_general_dilated_patches(lhs, …[, …]) 提取符合 conv_general_dilated 接受域的补丁。
conv_transpose(lhs, rhs, strides, padding[, …]) 计算 N 维卷积的“转置”的便捷封装器。
conv_with_general_padding(lhs, rhs, …[, …]) conv_general_dilated的便捷封装器。
cos(x) 元素级余弦函数:(\mathrm{cos}(x))。
cosh(x) 元素级双曲余弦函数:(\mathrm{cosh}(x))。
cumlogsumexp(operand[, axis, reverse]) 沿轴计算累积 logsumexp。
cummax(operand[, axis, reverse]) 沿轴计算累积最大值。
cummin(operand[, axis, reverse]) 沿轴计算累积最小值。
cumprod(operand[, axis, reverse]) 沿轴计算累积乘积。
cumsum(operand[, axis, reverse]) 沿轴计算累积和。
digamma(x) 元素级 digamma 函数:(\psi(x))。
div(x, y) 元素级除法:(x \over y)。
dot(lhs, rhs[, precision, …]) 向量/向量,矩阵/向量和矩阵/矩阵乘法。
dot_general(lhs, rhs, dimension_numbers[, …]) 通用的点积/收缩运算符。
dynamic_index_in_dim(operand, index[, axis, …]) dynamic_slice 的便捷封装,用于执行整数索引。
dynamic_slice(operand, start_indices, …) 封装了 XLA 的 DynamicSlice 操作符。
dynamic_slice_in_dim(operand, start_index, …) 方便地封装了应用于单个维度的 lax.dynamic_slice()
dynamic_update_index_in_dim(operand, update, …) 方便地封装了 dynamic_update_slice(),用于在单个 axis 中更新大小为 1 的切片。
dynamic_update_slice(operand, update, …) 封装了 XLA 的 DynamicUpdateSlice 操作符。
dynamic_update_slice_in_dim(operand, update, …) 方便地封装了 dynamic_update_slice(),用于在单个 axis 中更新一个切片。
eq(x, y) 元素级相等:(x = y)。
erf(x) 元素级误差函数:(\mathrm{erf}(x))。
erfc(x) 元素级补充误差函数:(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x))。
erf_inv(x) 元素级反误差函数:(\mathrm{erf}^{-1}(x))。
exp(x) 元素级指数函数:(e^x)。
expand_dims(array, dimensions) 将任意数量的大小为 1 的维度插入到数组中。
expm1(x) 元素级运算 (e^{x} - 1)。
fft(x, fft_type, fft_lengths)
floor(x) 元素级向下取整:(\left\lfloor x \right\rfloor)。
full(shape, fill_value[, dtype, sharding]) 返回填充值为 fill_value 的形状数组。
full_like(x, fill_value[, dtype, shape, …]) 基于示例数组 x 创建类似于 np.full 的完整数组。
gather(operand, start_indices, …[, …]) Gather 操作符。
ge(x, y) 元素级大于或等于:(x \geq y)。
gt(x, y) 元素级大于:(x > y)。
igamma(a, x) 元素级正则化不完全 gamma 函数。
igammac(a, x) 元素级补充正则化不完全 gamma 函数。
imag(x) 提取复数的虚部:(\mathrm{Im}(x))。
index_in_dim(operand, index[, axis, keepdims]) 方便地封装了 lax.slice(),用于执行整数索引。
index_take(src, idxs, axes)
integer_pow(x, y) 元素级幂运算:(x^y),其中 (y) 是固定整数。
iota(dtype, size) 封装了 XLA 的 Iota 操作符。
is_finite(x) 元素级 (\mathrm{isfinite})。
le(x, y) 元素级小于或等于:(x \leq y)。
lgamma(x) 元素级对数 gamma 函数:(\mathrm{log}(\Gamma(x)))。
log(x) 元素级自然对数:(\mathrm{log}(x))。
log1p(x) 元素级 (\mathrm{log}(1 + x))。
logistic(x) 元素级 logistic(sigmoid)函数:(\frac{1}{1 + e^{-x}})。
lt(x, y) 元素级小于:(x < y)。
max(x, y) 元素级最大值:(\mathrm{max}(x, y))
min(x, y) 元素级最小值:(\mathrm{min}(x, y))
mul(x, y) 元素级乘法:(x \times y)。
ne(x, y) 按位不等于:(x \neq y)。
neg(x) 按位取负:(-x)。
nextafter(x1, x2) 返回 x1 在 x2 方向上的下一个可表示的值。
pad(operand, padding_value, padding_config) 对数组应用低、高和/或内部填充。
polygamma(m, x) 按位多次 gamma 函数:(\psi^{(m)}(x))。
population_count(x) 按位人口统计,统计每个元素中设置的位数。
pow(x, y) 按位幂运算:(x^y)。
random_gamma_grad(a, x) Gamma 分布导数的按位计算。
real(x) 按位提取实部:(\mathrm{Re}(x))。
reciprocal(x) 按位倒数:(1 \over x)。
reduce(operands, init_values, computation, …) 封装了 XLA 的 Reduce 运算符。
reduce_precision(operand, exponent_bits, …) 封装了 XLA 的 ReducePrecision 运算符。
reduce_window(operand, init_value, …[, …])
rem(x, y) 按位取余:(x \bmod y)。
reshape(operand, new_sizes[, dimensions]) 封装了 XLA 的 Reshape 运算符。
rev(operand, dimensions) 封装了 XLA 的 Rev 运算符。
rng_bit_generator(key, shape[, dtype, algorithm]) 无状态的伪随机数位生成器。
rng_uniform(a, b, shape) 有状态的伪随机数生成器。
round(x[, rounding_method]) 按位四舍五入。
rsqrt(x) 按位倒数平方根:(1 \over \sqrt{x})。
scatter(operand, scatter_indices, updates, …) Scatter-update 运算符。
scatter_add(operand, scatter_indices, …[, …]) Scatter-add 运算符。
scatter_apply(operand, scatter_indices, …) Scatter-apply 运算符。
scatter_max(operand, scatter_indices, …[, …]) Scatter-max 运算符。
scatter_min(operand, scatter_indices, …[, …]) Scatter-min 运算符。
scatter_mul(operand, scatter_indices, …[, …]) Scatter-multiply 运算符。
shift_left(x, y) 按位左移:(x \ll y)。
shift_right_arithmetic(x, y) 按位算术右移:(x \gg y)。
shift_right_logical(x, y) 按位逻辑右移:(x \gg y)。
sign(x) 按位符号函数。
sin(x) 按位正弦函数:(\mathrm{sin}(x))。
sinh(x) 按位双曲正弦函数:(\mathrm{sinh}(x))。
slice(operand, start_indices, limit_indices) 封装了 XLA 的 Slice 运算符。
slice_in_dim(operand, start_index, limit_index) lax.slice() 的单维度应用封装。
sort() 封装了 XLA 的 Sort 运算符。
sort_key_val(keys, values[, dimension, …]) 沿着dimension排序keys并对values应用相同的置换。
sqrt(x) 逐元素平方根:(\sqrt{x})。
square(x) 逐元素平方:(x²)。
squeeze(array, dimensions) 从数组中挤出任意数量的大小为 1 的维度。
sub(x, y) 逐元素减法:(x - y)。
tan(x) 逐元素正切:(\mathrm{tan}(x))。
tanh(x) 逐元素双曲正切:(\mathrm{tanh}(x))。
top_k(operand, k) 返回operand最后一轴上的前k个值及其索引。
transpose(operand, permutation) 包装 XLA 的Transpose运算符。
zeros_like_array(x)
zeta(x, q) 逐元素 Hurwitz zeta 函数:(\zeta(x, q))

控制流操作符

associative_scan(fn, elems[, reverse, axis]) 使用关联二元操作并行执行扫描。
cond(pred, true_fun, false_fun, *operands[, …]) 根据条件应用true_funfalse_fun
fori_loop(lower, upper, body_fun, init_val, *) 通过归约到jax.lax.while_loop()lowerupper循环。
map(f, xs) 在主要数组轴上映射函数。
scan(f, init[, xs, length, reverse, unroll, …]) 在主要数组轴上扫描函数并携带状态。
select(pred, on_true, on_false) 根据布尔谓词在两个分支之间选择。
select_n(which, *cases) 从多个情况中选择数组值。
switch(index, branches, *operands[, operand]) 根据index应用恰好一个branches
while_loop(cond_fun, body_fun, init_val) cond_fun为 True 时重复调用body_fun

自定义梯度操作符

stop_gradient(x) 停止梯度计算。
custom_linear_solve(matvec, b, solve[, …]) 使用隐式定义的梯度执行无矩阵线性求解。
custom_root(f, initial_guess, solve, …[, …]) 可微分求解函数的根。

并行操作符

all_gather(x, axis_name, *[, …]) 在所有副本中收集x的值。
all_to_all(x, axis_name, split_axis, …[, …]) 映射轴的实例化和映射不同轴。
pdot(x, y, axis_name[, pos_contract, …])
psum(x, axis_name, *[, axis_index_groups]) 在映射的轴axis_name上进行全归约求和。
psum_scatter(x, axis_name, *[, …]) psum(x, axis_name),但每个设备仅保留部分结果。
pmax(x, axis_name, *[, axis_index_groups]) 在映射的轴axis_name上计算全归约最大值。
pmin(x, axis_name, *[, axis_index_groups]) 在映射的轴axis_name上计算全归约最小值。
pmean(x, axis_name, *[, axis_index_groups]) 在映射的轴axis_name上计算全归约均值。
ppermute(x, axis_name, perm) 根据置换 perm 执行集体置换。
pshuffle(x, axis_name, perm) 使用替代置换编码的 jax.lax.ppermute 的便捷包装器
pswapaxes(x, axis_name, axis, *[, …]) 将 pmapped 轴 axis_name 与非映射轴 axis 交换。
axis_index(axis_name) 返回沿映射轴 axis_name 的索引。

与分片相关的操作符

with_sharding_constraint(x, shardings) 在 jitted 计算中约束数组的分片机制

线性代数操作符 (jax.lax.linalg)

cholesky(x, *[, symmetrize_input]) Cholesky 分解。
eig(x, *[, compute_left_eigenvectors, …]) 一般矩阵的特征分解。
eigh(x, *[, lower, symmetrize_input, …]) Hermite 矩阵的特征分解。
hessenberg(a) 将方阵约化为上 Hessenberg 形式。
lu(x) 带有部分主元列主元分解。
householder_product(a, taus) 单元 Householder 反射的乘积。
qdwh(x, *[, is_hermitian, max_iterations, …]) 基于 QR 的动态加权 Halley 迭代进行极分解。
qr(x, *[, full_matrices]) QR 分解。
schur(x, *[, compute_schur_vectors, …])
svd() 奇异值分解。
triangular_solve(a, b, *[, left_side, …]) 三角解法。
tridiagonal(a, *[, lower]) 将对称/Hermitian 矩阵约化为三对角形式。
tridiagonal_solve(dl, d, du, b) 计算三对角线性系统的解。

参数类

class jax.lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)

描述卷积的批量、空间和特征维度。

参数:

  • lhs_spec (Sequence[int]) – 包含非负整数维度编号的元组,其中包括(批量维度,特征维度,空间维度…)。
  • rhs_spec (Sequence[int]) – 包含非负整数维度编号的元组,其中包括(输出特征维度,输入特征维度,空间维度…)。
  • out_spec (Sequence[int]) – 包含非负整数维度编号的元组,其中包括(批量维度,特征维度,空间维度…)。
jax.lax.ConvGeneralDilatedDimensionNumbers

alias of tuple[str, str, str] | ConvDimensionNumbers | None

class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map)

描述了传递给 XLA 的 Gather 运算符 的维度号参数。有关维度号含义的详细信息,请参阅 XLA 文档。

Parameters:

  • offset_dims (tuple[int, …**]) – gather 输出中偏移到从操作数切片的数组中的维度的集合。必须是升序整数元组,每个代表输出的一个维度编号。
  • collapsed_slice_dims (tuple[int, …**]) – operand 中具有 slice_sizes[i] == 1 的维度 i 的集合,这些维度不应在 gather 输出中具有对应维度。必须是一个升序整数元组。
  • start_index_map (tuple[int, …**]) – 对于 start_indices 中的每个维度,给出应该被切片的操作数中对应的维度。必须是一个大小等于 start_indices.shape[-1] 的整数元组。

与 XLA 的 GatherDimensionNumbers 结构不同,index_vector_dim 是隐含的;总是存在一个索引向量维度,且它必须始终是最后一个维度。要收集标量索引,请添加大小为 1 的尾随维度。

class jax.lax.GatherScatterMode(value)

描述了如何处理 gather 或 scatter 中的越界索引。

可能的值包括:

CLIP:

索引将被夹在最近的范围内值上,即整个要收集的窗口都在范围内。

FILL_OR_DROP:

如果收集窗口的任何部分越界,则返回整个窗口,即使其他部分原本在界内的元素也将用常量填充。如果分散窗口的任何部分越界,则整个窗口将被丢弃。

PROMISE_IN_BOUNDS:

用户承诺索引在范围内。不会执行额外检查。实际上,根据当前的 XLA 实现,这意味着越界的 gather 将被夹在范围内,但越界的 scatter 将被丢弃。如果索引越界,则梯度将不正确。

class jax.lax.Precision(value)

lax 函数的精度枚举

JAX 函数的精度参数通常控制加速器后端(即 TPU 和 GPU)上的数组计算速度和精度之间的权衡。成员包括:

默认:

最快模式,但最不准确。在 bfloat16 中执行计算。别名:'default''fastest''bfloat16'

高:

较慢但更准确。以 3 个 bfloat16 传递执行 float32 计算,或在可用时使用 tensorfloat32。别名:'high''bfloat16_3x''tensorfloat32'

最高:

最慢但最准确。根据适用情况在 float32 或 float64 中执行计算。别名:'highest''float32'

jax.lax.PrecisionLike

别名为 str | Precision | tuple[str, str] | tuple[Precision, Precision] | None

class jax.lax.RoundingMethod(value)

一个枚举。

class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)

描述了对 XLA 的 Scatter 操作符 的维度编号参数。有关维度编号含义的更多详细信息,请参阅 XLA 文档。

参数:

  • update_window_dims (Sequence[int]) – 更新中作为窗口维度的维度集合。必须是整数元组,按升序排列,每个表示一个维度编号。
  • inserted_window_dims (Sequence[int]) – 必须插入更新形状的大小为 1 的窗口维度集合。必须是整数元组,按升序排列,每个表示输出的维度编号的镜像图。这些是 gather 情况下 collapsed_slice_dims 的镜像图。
  • scatter_dims_to_operand_dims (Sequence[int]) – 对于 scatter_indices 中的每个维度,给出 operand 中对应的维度。必须是整数序列,大小等于 scatter_indices.shape[-1]。

与 XLA 的 ScatterDimensionNumbers 结构不同,index_vector_dim 是隐式的;总是有一个索引向量维度,并且它必须始终是最后一个维度。要分散标量索引,添加一个尺寸为 1 的尾随维度。


JAX 中文文档(十四)(3)https://developer.aliyun.com/article/1559758

相关文章
|
4月前
|
存储 API 索引
JAX 中文文档(十五)(5)
JAX 中文文档(十五)
53 3
|
4月前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(4)
JAX 中文文档(十五)
40 3
|
4月前
|
机器学习/深度学习 数据可视化 编译器
JAX 中文文档(十四)(5)
JAX 中文文档(十四)
51 2
|
4月前
|
算法 API 开发工具
JAX 中文文档(十二)(5)
JAX 中文文档(十二)
55 1
|
4月前
|
API 异构计算 Python
JAX 中文文档(十一)(4)
JAX 中文文档(十一)
33 1
|
4月前
JAX 中文文档(十一)(5)
JAX 中文文档(十一)
17 1
|
4月前
|
存储 缓存 API
JAX 中文文档(十六)(1)
JAX 中文文档(十六)
35 1
|
4月前
|
关系型数据库
JAX 中文文档(十四)(1)
JAX 中文文档(十四)
33 0
|
4月前
|
资源调度 算法 安全
JAX 中文文档(十四)(3)
JAX 中文文档(十四)
45 0
|
4月前
|
Python
JAX 中文文档(十四)(4)
JAX 中文文档(十四)
32 0