JAX 中文文档(十四)(1)https://developer.aliyun.com/article/1559755
jax.lax 模块
jax.lax
是支持诸如 jax.numpy
等库的基本操作的库。通常会定义转换规则,例如 JVP 和批处理规则,作为对 jax.lax
基元的转换。
许多基元都是等价于 XLA 操作的薄包装,详细描述请参阅XLA 操作语义文档。
在可能的情况下,优先使用诸如 jax.numpy
等库,而不是直接使用 jax.lax
。jax.numpy
API 遵循 NumPy,因此比 jax.lax
API 更稳定,更不易更改。
Operators
abs (x) |
按元素绝对值:(|x|)。 |
acos (x) |
按元素求反余弦:(\mathrm{acos}(x))。 |
acosh (x) |
按元素求反双曲余弦:(\mathrm{acosh}(x))。 |
add (x, y) |
按元素加法:(x + y)。 |
after_all (*operands) |
合并一个或多个 XLA 令牌值。 |
approx_max_k (operand, k[, …]) |
以近似方式返回 operand 的最大 k 值及其索引。 |
approx_min_k (operand, k[, …]) |
以近似方式返回 operand 的最小 k 值及其索引。 |
argmax (operand, axis, index_dtype) |
计算沿着 axis 的最大元素的索引。 |
argmin (operand, axis, index_dtype) |
计算沿着 axis 的最小元素的索引。 |
asin (x) |
按元素求反正弦:(\mathrm{asin}(x))。 |
asinh (x) |
按元素求反双曲正弦:(\mathrm{asinh}(x))。 |
atan (x) |
按元素求反正切:(\mathrm{atan}(x))。 |
atan2 (x, y) |
两个变量的按元素反正切:(\mathrm{atan}({x \over y}))。 |
atanh (x) |
按元素求反双曲正切:(\mathrm{atanh}(x))。 |
batch_matmul (lhs, rhs[, precision]) |
批量矩阵乘法。 |
bessel_i0e (x) |
指数缩放修正贝塞尔函数 (0) 阶:(\mathrm{i0e}(x) = e^{-|x|} \mathrm{i0}(x)) |
bessel_i1e (x) |
指数缩放修正贝塞尔函数 (1) 阶:(\mathrm{i1e}(x) = e^{-|x|} \mathrm{i1}(x)) |
betainc (a, b, x) |
按元素的正则化不完全贝塔积分。 |
bitcast_convert_type (operand, new_dtype) |
按元素位转换。 |
bitwise_and (x, y) |
按位与运算:(x \wedge y)。 |
bitwise_not (x) |
按位取反:(\neg x)。 |
bitwise_or (x, y) |
按位或运算:(x \vee y)。 |
bitwise_xor (x, y) |
按位异或运算:(x \oplus y)。 |
population_count (x) |
按元素计算 popcount,即每个元素中设置的位数。 |
broadcast (operand, sizes) |
广播数组,添加新的前导维度。 |
broadcast_in_dim (operand, shape, …) |
包装 XLA 的 BroadcastInDim 操作符。 |
broadcast_shapes () |
返回经过 NumPy 广播后的形状。 |
broadcast_to_rank (x, rank) |
添加 1 的前导维度,使 x 的等级为 rank 。 |
broadcasted_iota (dtype, shape, dimension) |
iota 的便捷封装器。 |
cbrt (x) |
元素级立方根:(\sqrt[3]{x})。 |
ceil (x) |
元素级向上取整:(\left\lceil x \right\rceil)。 |
clamp (min, x, max) |
元素级 clamp 函数。 |
clz (x) |
元素级计算前导零的个数。 |
collapse (operand, start_dimension[, …]) |
将数组的维度折叠为单个维度。 |
complex (x, y) |
元素级构造复数:(x + jy)。 |
concatenate (operands, dimension) |
沿指定维度连接一系列数组。 |
conj (x) |
元素级复数的共轭函数:(\overline{x})。 |
conv (lhs, rhs, window_strides, padding[, …]) |
conv_general_dilated 的便捷封装器。 |
convert_element_type (operand, new_dtype) |
元素级类型转换。 |
conv_dimension_numbers (lhs_shape, rhs_shape, …) |
将卷积维度编号转换为 ConvDimensionNumbers 。 |
conv_general_dilated (lhs, rhs, …[, …]) |
带有可选扩展的通用 n 维卷积运算符。 |
conv_general_dilated_local (lhs, rhs, …[, …]) |
带有可选扩展的通用 n 维非共享卷积运算符。 |
conv_general_dilated_patches (lhs, …[, …]) |
提取符合 conv_general_dilated 接受域的补丁。 |
conv_transpose (lhs, rhs, strides, padding[, …]) |
计算 N 维卷积的“转置”的便捷封装器。 |
conv_with_general_padding (lhs, rhs, …[, …]) |
conv_general_dilated 的便捷封装器。 |
cos (x) |
元素级余弦函数:(\mathrm{cos}(x))。 |
cosh (x) |
元素级双曲余弦函数:(\mathrm{cosh}(x))。 |
cumlogsumexp (operand[, axis, reverse]) |
沿轴计算累积 logsumexp。 |
cummax (operand[, axis, reverse]) |
沿轴计算累积最大值。 |
cummin (operand[, axis, reverse]) |
沿轴计算累积最小值。 |
cumprod (operand[, axis, reverse]) |
沿轴计算累积乘积。 |
cumsum (operand[, axis, reverse]) |
沿轴计算累积和。 |
digamma (x) |
元素级 digamma 函数:(\psi(x))。 |
div (x, y) |
元素级除法:(x \over y)。 |
dot (lhs, rhs[, precision, …]) |
向量/向量,矩阵/向量和矩阵/矩阵乘法。 |
dot_general (lhs, rhs, dimension_numbers[, …]) |
通用的点积/收缩运算符。 |
dynamic_index_in_dim (operand, index[, axis, …]) |
对 dynamic_slice 的便捷封装,用于执行整数索引。 |
dynamic_slice (operand, start_indices, …) |
封装了 XLA 的 DynamicSlice 操作符。 |
dynamic_slice_in_dim (operand, start_index, …) |
方便地封装了应用于单个维度的 lax.dynamic_slice() 。 |
dynamic_update_index_in_dim (operand, update, …) |
方便地封装了 dynamic_update_slice() ,用于在单个 axis 中更新大小为 1 的切片。 |
dynamic_update_slice (operand, update, …) |
封装了 XLA 的 DynamicUpdateSlice 操作符。 |
dynamic_update_slice_in_dim (operand, update, …) |
方便地封装了 dynamic_update_slice() ,用于在单个 axis 中更新一个切片。 |
eq (x, y) |
元素级相等:(x = y)。 |
erf (x) |
元素级误差函数:(\mathrm{erf}(x))。 |
erfc (x) |
元素级补充误差函数:(\mathrm{erfc}(x) = 1 - \mathrm{erf}(x))。 |
erf_inv (x) |
元素级反误差函数:(\mathrm{erf}^{-1}(x))。 |
exp (x) |
元素级指数函数:(e^x)。 |
expand_dims (array, dimensions) |
将任意数量的大小为 1 的维度插入到数组中。 |
expm1 (x) |
元素级运算 (e^{x} - 1)。 |
fft (x, fft_type, fft_lengths) |
|
floor (x) |
元素级向下取整:(\left\lfloor x \right\rfloor)。 |
full (shape, fill_value[, dtype, sharding]) |
返回填充值为 fill_value 的形状数组。 |
full_like (x, fill_value[, dtype, shape, …]) |
基于示例数组 x 创建类似于 np.full 的完整数组。 |
gather (operand, start_indices, …[, …]) |
Gather 操作符。 |
ge (x, y) |
元素级大于或等于:(x \geq y)。 |
gt (x, y) |
元素级大于:(x > y)。 |
igamma (a, x) |
元素级正则化不完全 gamma 函数。 |
igammac (a, x) |
元素级补充正则化不完全 gamma 函数。 |
imag (x) |
提取复数的虚部:(\mathrm{Im}(x))。 |
index_in_dim (operand, index[, axis, keepdims]) |
方便地封装了 lax.slice() ,用于执行整数索引。 |
index_take (src, idxs, axes) |
|
integer_pow (x, y) |
元素级幂运算:(x^y),其中 (y) 是固定整数。 |
iota (dtype, size) |
封装了 XLA 的 Iota 操作符。 |
is_finite (x) |
元素级 (\mathrm{isfinite})。 |
le (x, y) |
元素级小于或等于:(x \leq y)。 |
lgamma (x) |
元素级对数 gamma 函数:(\mathrm{log}(\Gamma(x)))。 |
log (x) |
元素级自然对数:(\mathrm{log}(x))。 |
log1p (x) |
元素级 (\mathrm{log}(1 + x))。 |
logistic (x) |
元素级 logistic(sigmoid)函数:(\frac{1}{1 + e^{-x}})。 |
lt (x, y) |
元素级小于:(x < y)。 |
max (x, y) |
元素级最大值:(\mathrm{max}(x, y)) |
min (x, y) |
元素级最小值:(\mathrm{min}(x, y)) |
mul (x, y) |
元素级乘法:(x \times y)。 |
ne (x, y) |
按位不等于:(x \neq y)。 |
neg (x) |
按位取负:(-x)。 |
nextafter (x1, x2) |
返回 x1 在 x2 方向上的下一个可表示的值。 |
pad (operand, padding_value, padding_config) |
对数组应用低、高和/或内部填充。 |
polygamma (m, x) |
按位多次 gamma 函数:(\psi^{(m)}(x))。 |
population_count (x) |
按位人口统计,统计每个元素中设置的位数。 |
pow (x, y) |
按位幂运算:(x^y)。 |
random_gamma_grad (a, x) |
Gamma 分布导数的按位计算。 |
real (x) |
按位提取实部:(\mathrm{Re}(x))。 |
reciprocal (x) |
按位倒数:(1 \over x)。 |
reduce (operands, init_values, computation, …) |
封装了 XLA 的 Reduce 运算符。 |
reduce_precision (operand, exponent_bits, …) |
封装了 XLA 的 ReducePrecision 运算符。 |
reduce_window (operand, init_value, …[, …]) |
|
rem (x, y) |
按位取余:(x \bmod y)。 |
reshape (operand, new_sizes[, dimensions]) |
封装了 XLA 的 Reshape 运算符。 |
rev (operand, dimensions) |
封装了 XLA 的 Rev 运算符。 |
rng_bit_generator (key, shape[, dtype, algorithm]) |
无状态的伪随机数位生成器。 |
rng_uniform (a, b, shape) |
有状态的伪随机数生成器。 |
round (x[, rounding_method]) |
按位四舍五入。 |
rsqrt (x) |
按位倒数平方根:(1 \over \sqrt{x})。 |
scatter (operand, scatter_indices, updates, …) |
Scatter-update 运算符。 |
scatter_add (operand, scatter_indices, …[, …]) |
Scatter-add 运算符。 |
scatter_apply (operand, scatter_indices, …) |
Scatter-apply 运算符。 |
scatter_max (operand, scatter_indices, …[, …]) |
Scatter-max 运算符。 |
scatter_min (operand, scatter_indices, …[, …]) |
Scatter-min 运算符。 |
scatter_mul (operand, scatter_indices, …[, …]) |
Scatter-multiply 运算符。 |
shift_left (x, y) |
按位左移:(x \ll y)。 |
shift_right_arithmetic (x, y) |
按位算术右移:(x \gg y)。 |
shift_right_logical (x, y) |
按位逻辑右移:(x \gg y)。 |
sign (x) |
按位符号函数。 |
sin (x) |
按位正弦函数:(\mathrm{sin}(x))。 |
sinh (x) |
按位双曲正弦函数:(\mathrm{sinh}(x))。 |
slice (operand, start_indices, limit_indices) |
封装了 XLA 的 Slice 运算符。 |
slice_in_dim (operand, start_index, limit_index) |
lax.slice() 的单维度应用封装。 |
sort () |
封装了 XLA 的 Sort 运算符。 |
sort_key_val (keys, values[, dimension, …]) |
沿着dimension 排序keys 并对values 应用相同的置换。 |
sqrt (x) |
逐元素平方根:(\sqrt{x})。 |
square (x) |
逐元素平方:(x²)。 |
squeeze (array, dimensions) |
从数组中挤出任意数量的大小为 1 的维度。 |
sub (x, y) |
逐元素减法:(x - y)。 |
tan (x) |
逐元素正切:(\mathrm{tan}(x))。 |
tanh (x) |
逐元素双曲正切:(\mathrm{tanh}(x))。 |
top_k (operand, k) |
返回operand 最后一轴上的前k 个值及其索引。 |
transpose (operand, permutation) |
包装 XLA 的Transpose运算符。 |
zeros_like_array (x) |
|
zeta (x, q) |
逐元素 Hurwitz zeta 函数:(\zeta(x, q)) |
控制流操作符
associative_scan (fn, elems[, reverse, axis]) |
使用关联二元操作并行执行扫描。 |
cond (pred, true_fun, false_fun, *operands[, …]) |
根据条件应用true_fun 或false_fun 。 |
fori_loop (lower, upper, body_fun, init_val, *) |
通过归约到jax.lax.while_loop() 从lower 到upper 循环。 |
map (f, xs) |
在主要数组轴上映射函数。 |
scan (f, init[, xs, length, reverse, unroll, …]) |
在主要数组轴上扫描函数并携带状态。 |
select (pred, on_true, on_false) |
根据布尔谓词在两个分支之间选择。 |
select_n (which, *cases) |
从多个情况中选择数组值。 |
switch (index, branches, *operands[, operand]) |
根据index 应用恰好一个branches 。 |
while_loop (cond_fun, body_fun, init_val) |
在cond_fun 为 True 时重复调用body_fun 。 |
自定义梯度操作符
stop_gradient (x) |
停止梯度计算。 |
custom_linear_solve (matvec, b, solve[, …]) |
使用隐式定义的梯度执行无矩阵线性求解。 |
custom_root (f, initial_guess, solve, …[, …]) |
可微分求解函数的根。 |
并行操作符
all_gather (x, axis_name, *[, …]) |
在所有副本中收集x 的值。 |
all_to_all (x, axis_name, split_axis, …[, …]) |
映射轴的实例化和映射不同轴。 |
pdot (x, y, axis_name[, pos_contract, …]) |
|
psum (x, axis_name, *[, axis_index_groups]) |
在映射的轴axis_name 上进行全归约求和。 |
psum_scatter (x, axis_name, *[, …]) |
像psum(x, axis_name) ,但每个设备仅保留部分结果。 |
pmax (x, axis_name, *[, axis_index_groups]) |
在映射的轴axis_name 上计算全归约最大值。 |
pmin (x, axis_name, *[, axis_index_groups]) |
在映射的轴axis_name 上计算全归约最小值。 |
pmean (x, axis_name, *[, axis_index_groups]) |
在映射的轴axis_name 上计算全归约均值。 |
ppermute (x, axis_name, perm) |
根据置换 perm 执行集体置换。 |
pshuffle (x, axis_name, perm) |
使用替代置换编码的 jax.lax.ppermute 的便捷包装器 |
pswapaxes (x, axis_name, axis, *[, …]) |
将 pmapped 轴 axis_name 与非映射轴 axis 交换。 |
axis_index (axis_name) |
返回沿映射轴 axis_name 的索引。 |
与分片相关的操作符
with_sharding_constraint (x, shardings) |
在 jitted 计算中约束数组的分片机制 |
线性代数操作符 (jax.lax.linalg)
cholesky (x, *[, symmetrize_input]) |
Cholesky 分解。 |
eig (x, *[, compute_left_eigenvectors, …]) |
一般矩阵的特征分解。 |
eigh (x, *[, lower, symmetrize_input, …]) |
Hermite 矩阵的特征分解。 |
hessenberg (a) |
将方阵约化为上 Hessenberg 形式。 |
lu (x) |
带有部分主元列主元分解。 |
householder_product (a, taus) |
单元 Householder 反射的乘积。 |
qdwh (x, *[, is_hermitian, max_iterations, …]) |
基于 QR 的动态加权 Halley 迭代进行极分解。 |
qr (x, *[, full_matrices]) |
QR 分解。 |
schur (x, *[, compute_schur_vectors, …]) |
|
svd () |
奇异值分解。 |
triangular_solve (a, b, *[, left_side, …]) |
三角解法。 |
tridiagonal (a, *[, lower]) |
将对称/Hermitian 矩阵约化为三对角形式。 |
tridiagonal_solve (dl, d, du, b) |
计算三对角线性系统的解。 |
参数类
class jax.lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
描述卷积的批量、空间和特征维度。
参数:
- lhs_spec (Sequence[int]) – 包含非负整数维度编号的元组,其中包括(批量维度,特征维度,空间维度…)。
- rhs_spec (Sequence[int]) – 包含非负整数维度编号的元组,其中包括(输出特征维度,输入特征维度,空间维度…)。
- out_spec (Sequence[int]) – 包含非负整数维度编号的元组,其中包括(批量维度,特征维度,空间维度…)。
jax.lax.ConvGeneralDilatedDimensionNumbers
alias of tuple
[str
, str
, str
] | ConvDimensionNumbers
| None
class jax.lax.GatherDimensionNumbers(offset_dims, collapsed_slice_dims, start_index_map)
描述了传递给 XLA 的 Gather 运算符 的维度号参数。有关维度号含义的详细信息,请参阅 XLA 文档。
Parameters:
- offset_dims (tuple[int, …**]) – gather 输出中偏移到从操作数切片的数组中的维度的集合。必须是升序整数元组,每个代表输出的一个维度编号。
- collapsed_slice_dims (tuple[int, …**]) – operand 中具有 slice_sizes[i] == 1 的维度 i 的集合,这些维度不应在 gather 输出中具有对应维度。必须是一个升序整数元组。
- start_index_map (tuple[int, …**]) – 对于 start_indices 中的每个维度,给出应该被切片的操作数中对应的维度。必须是一个大小等于 start_indices.shape[-1] 的整数元组。
与 XLA 的 GatherDimensionNumbers 结构不同,index_vector_dim 是隐含的;总是存在一个索引向量维度,且它必须始终是最后一个维度。要收集标量索引,请添加大小为 1 的尾随维度。
class jax.lax.GatherScatterMode(value)
描述了如何处理 gather 或 scatter 中的越界索引。
可能的值包括:
CLIP:
索引将被夹在最近的范围内值上,即整个要收集的窗口都在范围内。
FILL_OR_DROP:
如果收集窗口的任何部分越界,则返回整个窗口,即使其他部分原本在界内的元素也将用常量填充。如果分散窗口的任何部分越界,则整个窗口将被丢弃。
PROMISE_IN_BOUNDS:
用户承诺索引在范围内。不会执行额外检查。实际上,根据当前的 XLA 实现,这意味着越界的 gather 将被夹在范围内,但越界的 scatter 将被丢弃。如果索引越界,则梯度将不正确。
class jax.lax.Precision(value)
lax 函数的精度枚举
JAX 函数的精度参数通常控制加速器后端(即 TPU 和 GPU)上的数组计算速度和精度之间的权衡。成员包括:
默认:
最快模式,但最不准确。在 bfloat16 中执行计算。别名:'default'
,'fastest'
,'bfloat16'
。
高:
较慢但更准确。以 3 个 bfloat16 传递执行 float32 计算,或在可用时使用 tensorfloat32。别名:'high'
,'bfloat16_3x'
,'tensorfloat32'
。
最高:
最慢但最准确。根据适用情况在 float32 或 float64 中执行计算。别名:'highest'
,'float32'
。
jax.lax.PrecisionLike
别名为 str
| Precision
| tuple
[str
, str
] | tuple
[Precision
, Precision
] | None
class jax.lax.RoundingMethod(value)
一个枚举。
class jax.lax.ScatterDimensionNumbers(update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)
描述了对 XLA 的 Scatter 操作符 的维度编号参数。有关维度编号含义的更多详细信息,请参阅 XLA 文档。
参数:
- update_window_dims (Sequence[int]) – 更新中作为窗口维度的维度集合。必须是整数元组,按升序排列,每个表示一个维度编号。
- inserted_window_dims (Sequence[int]) – 必须插入更新形状的大小为 1 的窗口维度集合。必须是整数元组,按升序排列,每个表示输出的维度编号的镜像图。这些是 gather 情况下 collapsed_slice_dims 的镜像图。
- scatter_dims_to_operand_dims (Sequence[int]) – 对于 scatter_indices 中的每个维度,给出 operand 中对应的维度。必须是整数序列,大小等于 scatter_indices.shape[-1]。
与 XLA 的 ScatterDimensionNumbers 结构不同,index_vector_dim 是隐式的;总是有一个索引向量维度,并且它必须始终是最后一个维度。要分散标量索引,添加一个尺寸为 1 的尾随维度。
JAX 中文文档(十四)(3)https://developer.aliyun.com/article/1559758