JAX 中文文档(十三)(4)

简介: JAX 中文文档(十三)

JAX 中文文档(十三)(3)https://developer.aliyun.com/article/1559743


jax.numpy 模块

原文:jax.readthedocs.io/en/latest/jax.numpy.html

采用jax.lax中的原语实现 NumPy API。

虽然 JAX 尽可能地遵循 NumPy API,但有时无法完全遵循 NumPy 的规范。

  • 值得注意的是,由于 JAX 数组是不可变的,不能在 JAX 中实现原地变换数组的 NumPy API。但是,JAX 通常能够提供纯函数的替代 API。例如,替代原地数组更新(x[i] = y),JAX 提供了一个纯索引更新函数 x.at[i].set(y)(参见ndarray.at)。
  • 类似地,一些 NumPy 函数在可能时经常返回数组的视图(例如transpose()reshape())。JAX 版本的这类函数将返回副本,尽管在使用jax.jit()编译操作序列时,XLA 通常会进行优化。
  • NumPy 在将值提升为float64类型时非常积极。JAX 在类型提升方面有时不那么积极(请参阅类型提升语义)。
  • 一些 NumPy 例程具有依赖数据的输出形状(例如unique()nonzero())。因为 XLA 编译器要求在编译时知道数组形状,这些操作与 JIT 不兼容。因此,JAX 在这些函数中添加了一个可选的size参数,可以在静态指定以便与 JIT 一起使用。

几乎所有适用的 NumPy 函数都在jax.numpy命名空间中实现;它们如下所列。

ndarray.at 用于索引更新功能的辅助属性。
abs(x, /) jax.numpy.absolute()的别名。
absolute(x, /) 计算逐元素的绝对值。
acos(x, /) 逐元素的反余弦函数。
acosh(x, /) 逐元素的反双曲余弦函数。
add(x1, x2, /) 逐元素相加。
all(a[, axis, out, keepdims, where]) 测试沿给定轴的所有数组元素是否为 True。
allclose(a, b[, rtol, atol, equal_nan]) 如果两个数组在容差范围内逐元素相等,则返回 True。
amax(a[, axis, out, keepdims, initial, where]) 返回数组或沿轴的最大值。
amin(a[, axis, out, keepdims, initial, where]) 返回数组或沿轴的最小值。
angle(z[, deg]) 返回复数或数组的角度。
any(a[, axis, out, keepdims, where]) 测试沿给定轴的任何数组元素是否为 True。
append(arr, values[, axis]) 返回将值附加到原始数组末尾的新数组。
apply_along_axis(func1d, axis, arr, *args, …) 沿给定轴向数组的 1-D 切片应用函数。
apply_over_axes(func, a, axes) 在多个轴上重复应用函数。
arange(start[, stop, step, dtype]) 返回给定间隔内的均匀间隔值。
arccos(x, /) 反余弦,逐元素计算。
arccosh(x, /) 逆双曲余弦,逐元素计算。
arcsin(x, /) 反正弦,逐元素计算。
arcsinh(x, /) 逆双曲正弦,逐元素计算。
arctan(x, /) 反三角正切,逐元素计算。
arctan2(x1, x2, /) 根据 x1/x2 的值选择正确的象限,逐元素计算反正切。
arctanh(x, /) 逆双曲正切,逐元素计算。
argmax(a[, axis, out, keepdims]) 返回沿轴的最大值的索引。
argmin(a[, axis, out, keepdims]) 返回沿轴的最小值的索引。
argpartition(a, kth[, axis]) 返回部分排序数组的索引。
argsort(a[, axis, kind, order, stable, …]) 返回排序数组的索引。
argwhere(a, *[, size, fill_value]) 查找非零数组元素的索引。
around(a[, decimals, out]) 将数组四舍五入到指定的小数位数。
array(object[, dtype, copy, order, ndmin]) 创建一个数组。
array_equal(a1, a2[, equal_nan]) 如果两个数组具有相同的形状和元素则返回 True。
array_equiv(a1, a2) 如果输入数组形状一致且所有元素相等则返回 True。
array_repr(arr[, max_line_width, precision, …]) 返回数组的字符串表示。
array_split(ary, indices_or_sections[, axis]) 将数组分割为多个子数组。
array_str(a[, max_line_width, precision, …]) 返回数组中数据的字符串表示。
asarray(a[, dtype, order, copy]) 将输入转换为数组。
asin(x, /) 反正弦,逐元素计算。
asinh(x, /) 逆双曲正弦,逐元素计算。
astype(x, dtype, /, *[, copy, device]) 将数组复制到指定的数据类型。
atan(x, /) 反三角正切,逐元素计算。
atanh(x, /) 逆双曲正切,逐元素计算。
atan2(x1, x2, /) 根据 x1/x2 的值选择正确的象限,逐元素计算反正切。
atleast_1d() 将输入转换为至少有一维的数组。
atleast_2d() 将输入视为至少有两个维度的数组。
atleast_3d() 将输入视为至少有三个维度的数组。
average() 沿指定轴计算加权平均值。
bartlett(M) 返回 Bartlett 窗口。
bincount(x[, weights, minlength, length]) 计算整数数组中每个值的出现次数。
bitwise_and(x1, x2, /) 逐元素计算两个数组的按位与操作。
bitwise_count(x, /) 计算每个元素的绝对值的二进制表示中 1 的位数。
bitwise_invert(x, /) 计算按位求反,逐元素计算。
bitwise_left_shift(x1, x2, /) 将整数的位向左移动。
bitwise_not(x, /) 计算按位取反(bit-wise NOT),即按位取反,对每个元素进行操作。
bitwise_or(x1, x2, /) 计算两个数组按位或的结果。
bitwise_right_shift(x1, x2, /) 将整数的位向右移动。
bitwise_xor(x1, x2, /) 计算两个数组按位异或的结果。
blackman(M) 返回 Blackman 窗口。
block(arrays) 从嵌套的块列表中组装一个多维数组。
bool_ bool 的别名
broadcast_arrays(*args) 广播任意数量的数组。
broadcast_shapes() 将输入的形状广播为单个形状。
broadcast_to(array, shape) 将数组广播到新的形状。
c_ 沿着最后一个轴连接切片、标量和类数组对象。
can_cast(from_, to[, casting]) 根据转换规则,如果可以进行数据类型转换,则返回 True。
cbrt(x, /) 返回数组的立方根,按元素操作。
cdouble complex128 的别名
ceil(x, /) 返回输入的上限值,按元素操作。
character() 所有字符字符串标量类型的抽象基类。
choose(a, choices[, out, mode]) 根据索引数组和数组列表选择构造数组。
clip([x, min, max, a, a_min, a_max]) 将数组中的值限制在给定范围内。
column_stack(tup) 将一维数组按列堆叠成二维数组。
complex_ complex128 的别名
complex128(x)
complex64(x)
complexfloating() 所有由浮点数构成的复数数值标量类型的抽象基类。
ComplexWarning 在将复数数据类型强制转换为实数数据类型时引发的警告。
compress(condition, a[, axis, size, …]) 使用布尔条件沿指定轴压缩数组。
concat(arrays, /, *[, axis]) 沿着现有轴连接一系列数组。
concatenate(arrays[, axis, dtype]) 沿着指定轴连接一系列数组。
conj(x, /) 返回复数的共轭,按元素操作。
conjugate(x, /) 返回复数的共轭,按元素操作。
convolve(a, v[, mode, precision, …]) 计算两个一维数组的卷积。
copy(a[, order]) 返回给定对象的数组副本。
copysign(x1, x2, /) 将 x1 的符号改为 x2 的符号,按元素操作。
corrcoef(x[, y, rowvar]) 返回皮尔逊积矩相关系数。
correlate(a, v[, mode, precision, …]) 计算两个一维数组的相关性。
cos(x, /) 计算元素的余弦值。
cosh(x, /) 双曲余弦,按元素操作。
count_nonzero(a[, axis, keepdims]) 统计数组a中的非零值数量。
cov(m[, y, rowvar, bias, ddof, fweights, …]) 估算给定数据和权重的协方差矩阵。
cross(a, b[, axisa, axisb, axisc, axis]) 返回两个(向量)数组的叉积。
csingle complex64的别名。
cumprod(a[, axis, dtype, out]) 返回沿给定轴的元素的累积乘积。
cumsum(a[, axis, dtype, out]) 返回沿给定轴的元素的累积和。
cumulative_sum(x, /, *[, axis, dtype, …])
deg2rad(x, /) 将角度从度转换为弧度。
degrees(x, /) 将弧度从弧度转换为度。
delete(arr, obj[, axis, assume_unique_indices]) 从数组中删除条目或条目。
diag(v[, k]) 提取对角线或构造对角线数组。
diag_indices(n[, ndim]) 返回访问数组主对角线的索引。
diag_indices_from(arr) 返回 n 维数组的主对角线的访问索引。
diagflat(v[, k]) 用扁平化输入创建一个二维数组的对角线。
diagonal(a[, offset, axis1, axis2]) 返回指定对角线。
diff(a[, n, axis, prepend, append]) 计算给定轴的第 n 个离散差异。
digitize(x, bins[, right]) 返回输入数组中每个值所属的箱体的索引。
divide(x1, x2, /) 按元素划分参数。
divmod(x1, x2, /) 同时返回按元素的商和余数。
dot(a, b, *[, precision, preferred_element_type]) 计算两个数组的点积。
double float64的别名。
dsplit(ary, indices_or_sections) 沿第 3 轴(深度)将数组分割成多个子数组。
dstack(tup[, dtype]) 深度方向上序列堆叠数组(沿着第三个轴)。
dtype(dtype[, align, copy]) 创建一个数据类型对象。
ediff1d(ary[, to_end, to_begin]) 数组中连续元素的差异。
einsum() 爱因斯坦求和。
einsum_path() 在不评估 einsum 的情况下计算最佳收缩路径。
empty(shape[, dtype, device]) 返回给定形状和类型的新数组,不初始化条目。
empty_like(prototype[, dtype, shape, device]) 返回与给定数组相同形状和类型的新数组。
equal(x1, x2, /) 按元素返回(x1 == x2)。
exp(x, /) 计算输入数组中所有元素的指数。
exp2(x, /) 计算输入数组中所有 p 的 2**p。
expand_dims(a, axis) 将长度为 1 的维度插入数组。
expm1(x, /) 计算数组中所有元素的exp(x) - 1
extract(condition, arr, *[, size, fill_value]) 返回满足条件的数组元素。
eye(N[, M, k, dtype]) 返回对角线上为 1 的二维数组,其他位置为 0。
fabs(x, /) 计算每个元素的绝对值。
fill_diagonal(a, val[, wrap, inplace]) 填充给定任意维度数组的主对角线。
finfo(dtype) 浮点类型的机器限制。
fix(x[, out]) 四舍五入到最近的整数朝向零。
flatnonzero(a, *[, size, fill_value]) 返回扁平化数组中非零元素的索引。
flexible() 所有没有预定义长度的标量类型的抽象基类。
flip(m[, axis]) 沿指定轴翻转数组元素的顺序。
fliplr(m) 沿轴 1 翻转数组元素的顺序。
flipud(m) 沿轴 0 翻转数组元素的顺序。
float_ float64 的别名。
float_power(x1, x2, /) 逐元素地将第一个数组的元素提升为第二个数组的幂。
float16(x)
float32(x)
float64(x)
floating() 所有浮点标量类型的抽象基类。
floor(x, /) 逐元素返回输入的下限。
floor_divide(x1, x2, /) 返回输入除法的最大整数小于或等于结果的元素。
fmax(x1, x2) 数组元素的逐元素最大值。
fmin(x1, x2) 数组元素的逐元素最小值。
fmod(x1, x2, /) 返回除法的元素余数。
frexp(x, /) 将 x 的元素分解为尾数和二次指数。
frombuffer(buffer[, dtype, count, offset]) 将缓冲区解释为一维数组。
fromfile(*args, **kwargs) jnp.fromfile 的未实现 JAX 封装器。
fromfunction(function, shape, *[, dtype]) 通过对每个坐标执行函数来构造数组。
fromiter(*args, **kwargs) jnp.fromiter 的未实现 JAX 封装器。
frompyfunc(func, /, nin, nout, *[, identity]) 从任意 JAX 兼容的标量函数创建一个 JAX ufunc。
fromstring(string[, dtype, count]) 从字符串中的文本数据初始化一个新的一维数组。
from_dlpack(x, /, *[, device, copy]) 从实现了__dlpack__的对象创建一个 NumPy 数组。
full(shape, fill_value[, dtype, device]) 返回给定形状和类型的新数组,并填充 fill_value。
full_like(a, fill_value[, dtype, shape, device]) 返回与给定数组形状和类型相同的全数组。
gcd(x1, x2) 返回 |x1||x2| 的最大公约数。
generic() NumPy 标量类型的基类。
geomspace(start, stop[, num, endpoint, …]) 返回等间隔的对数刻度上的数字(等比数列)。
get_printoptions() 返回当前的打印选项。
gradient(f, *varargs[, axis, edge_order]) 返回 N 维数组的梯度。
greater(x1, x2, /) 返回逐元素 (x1 > x2) 的真值。
greater_equal(x1, x2, /) 返回逐元素 (x1 >= x2) 的真值。
hamming(M) 返回 Hamming 窗口。
hanning(M) 返回 Hanning 窗口。
heaviside(x1, x2, /) 计算 Heaviside 阶跃函数。
histogram(a[, bins, range, weights, density]) 计算数据集的直方图。
histogram_bin_edges(a[, bins, range, weights]) 计算直方图使用的箱子的边缘。
histogram2d(x, y[, bins, range, weights, …]) 计算两个数据样本的二维直方图。
histogramdd(sample[, bins, range, weights, …]) 计算一些数据的多维直方图。
hsplit(ary, indices_or_sections) 水平(按列)将数组分割为多个子数组。
hstack(tup[, dtype]) 按序列水平(按列)堆叠数组。
hypot(x1, x2, /) 给定直角三角形的“腿”,返回其斜边长度。
i0 第一类修正贝塞尔函数,阶数为 0。
identity(n[, dtype]) 返回单位数组。
iinfo(int_type)
imag(val, /) 返回复数参数的虚部。
index_exp 用于构建数组索引元组的更好方式。
indices() 返回表示网格的索引数组。
inexact() 所有数值标量类型的抽象基类,其值的表示(可能)是不精确的,如浮点数。
inner(a, b, *[, precision, …]) 计算两个数组的内积。
insert(arr, obj, values[, axis]) 在给定索引之前,沿着指定的轴插入值。
int_ int64的别名
int16(x)
int32(x)
int64(x)
int8(x)
integer() 所有整数标量类型的抽象基类。
interp(x, xp, fp[, left, right, period]) 单调递增样本点的一维线性插值。
intersect1d(ar1, ar2[, assume_unique, …]) 计算两个一维数组的交集。
invert(x, /) 按位求反,即按位非,逐元素进行操作。
isclose(a, b[, rtol, atol, equal_nan]) 返回一个布尔数组,其中两个数组在每个元素级别上是否在指定的公差内相等。
iscomplex(x) 返回一个布尔数组,如果输入元素是复数则为 True。
iscomplexobj(x) 检查复数类型或复数数组。
isdtype(dtype, kind) 返回一个布尔值,指示提供的 dtype 是否属于指定的 kind。
isfinite(x, /) 测试每个元素是否有限(既不是无穷大也不是非数)。
isin(element, test_elements[, …]) 确定element中的元素是否出现在test_elements中。
isinf(x, /) 逐元素测试是否为正或负无穷大。
isnan(x, /) 逐元素测试是否为 NaN,并返回布尔数组结果。
isneginf(x, /[, out]) 逐元素测试是否为负无穷大,返回布尔数组结果。
isposinf(x, /[, out]) 逐元素测试是否为正无穷大,返回布尔数组结果。
isreal(x) 返回一个布尔数组,如果输入元素是实数则为 True。
isrealobj(x) 如果 x 是非复数类型或复数数组,则返回 True。
isscalar(element) 如果 element 的类型是标量类型,则返回 True。
issubdtype(arg1, arg2) 如果第一个参数在类型层次结构中低于或等于第二个参数的类型码,则返回 True。
iterable(y) 检查对象是否可迭代。
ix_(*args) 从 N 个一维序列返回多维网格(开放网格)。
kaiser(M, beta) 返回 Kaiser 窗口。
kron(a, b) 两个数组的 Kronecker 乘积。
lcm(x1, x2) 返回 `
ldexp(x1, x2, /) 返回 x1 * 2**x2,逐元素操作。
left_shift(x1, x2, /) 将整数的位左移。
less(x1, x2, /) 逐元素返回 (x1 < x2) 的真值。
less_equal(x1, x2, /) 逐元素返回 (x1 <= x2) 的真值。
lexsort(keys[, axis]) 使用一系列键执行间接稳定排序。
linspace() 返回指定间隔内的均匀间隔数字。
load(*args, **kwargs) .npy.npz 或 pickled 文件中加载数组或序列化对象。
log(x, /) 自然对数,逐元素操作。
log10(x, /) 返回输入数组的以 10 为底的对数,逐元素操作。
log1p(x, /) 返回输入数组加 1 的自然对数,逐元素操作。
log2(x, /) x 的以 2 为底的对数,逐元素操作。
logaddexp 输入指数的对数之和。
logaddexp2 以 2 为底的指数输入的对数之和。
logical_and(*args) 逐元素计算 x1 AND x2 的真值。
logical_not(*args) 逐元素计算 NOT x 的真值。
logical_or(*args) 逐元素计算 x1 OR x2 的真值。
logical_xor(*args) 逐元素计算 x1 XOR x2 的真值。
logspace(start, stop[, num, endpoint, base, …]) 返回对数刻度上均匀分布的数字。
mask_indices(*args, **kwargs) 给定掩码函数,返回访问 (n, n) 数组的索引。
matmul(a, b, *[, precision, …]) 执行矩阵乘法。
matrix_transpose(x, /) 转置数组的最后两个维度。
max(a[, axis, out, keepdims, initial, where]) 返回数组或沿轴的最大值。
maximum(x1, x2, /) 逐元素计算数组元素的最大值。
mean(a[, axis, dtype, out, keepdims, where]) 沿指定轴计算算术平均值。
median(a[, axis, out, overwrite_input, keepdims]) 沿指定轴计算中位数。
meshgrid(*xi[, copy, sparse, indexing]) 从坐标向量返回坐标矩阵的元组。
mgrid 返回密集的多维网格。
min(a[, axis, out, keepdims, initial, where]) 返回数组或沿轴的最小值。
minimum(x1, x2, /) 逐元素计算数组元素的最小值。
mod(x1, x2, /) 返回除法的元素余数。
modf(x, /[, out]) 返回数组元素的整数部分和小数部分。
moveaxis(a, source, destination) 将数组轴移动到新位置
multiply(x1, x2, /) 对参数逐元素相乘。
nan_to_num(x[, copy, nan, posinf, neginf]) 将 NaN 替换为零,将无穷大替换为大的有限数(默认
nanargmax(a[, axis, out, keepdims]) 返回忽略指定轴上的 NaN 的最大值的索引
nanargmin(a[, axis, out, keepdims]) 返回忽略指定轴上的 NaN 的最小值的索引
nancumprod(a[, axis, dtype, out]) 返回沿指定轴对数组元素的累积积,处理 NaN 为
nancumsum(a[, axis, dtype, out]) 返回沿指定轴对数组元素的累积和,处理 NaN 为
nanmax(a[, axis, out, keepdims, initial, where]) 返回数组或指定轴上的最大值,忽略任何 NaN
nanmean(a[, axis, dtype, out, keepdims, where]) 计算沿指定轴的算术平均值,忽略 NaN
nanmedian(a[, axis, out, overwrite_input, …]) 计算沿指定轴的中位数,忽略 NaN
nanmin(a[, axis, out, keepdims, initial, where]) 返回数组或指定轴上的最小值,忽略任何 NaN
nanpercentile(a, q[, axis, out, …]) 计算沿指定轴的数据的第 q 分位数,
nanprod(a[, axis, dtype, out, keepdims, …]) 返回沿指定轴对数组元素求积,处理 NaN 为
nanquantile(a, q[, axis, out, …]) 计算沿指定轴的数据的第 q 分位数,
nanstd(a[, axis, dtype, out, ddof, …]) 计算沿指定轴的标准差,忽略 NaN
nansum(a[, axis, dtype, out, keepdims, …]) 返回沿指定轴对数组元素求和,处理 NaN 为
nanvar(a[, axis, dtype, out, ddof, …]) 计算沿指定轴的方差,忽略 NaN
ndarray Array 的别名。
ndim(a) 返回数组的维数。
negative(x, /) 数值取反,逐元素操作。
nextafter(x1, x2, /) 返回 x1 朝向 x2 的下一个浮点数值,逐元素操作。
nonzero(a, *[, size, fill_value]) 返回数组中非零元素的索引。
not_equal(x1, x2, /) 逐元素返回 (x1 != x2)。
number() 所有数值标量类型的抽象基类。
object_ 任何 Python 对象。
ogrid 返回开放多维“网格”。
ones(shape[, dtype, device]) 返回给定形状和类型的新数组,填充为 1。
ones_like(a[, dtype, shape, device]) 返回与给定数组具有相同形状和类型的填充为 1 的数组。
outer(a, b[, out]) 计算两个向量的外积。
packbits(a[, axis, bitorder]) 将二值数组的元素打包为 uint8 数组中的位。
pad(array, pad_width[, mode]) 对数组进行填充。
partition(a, kth[, axis]) 返回数组的部分排序副本。
percentile(a, q[, axis, out, …]) 计算沿指定轴的数据的第 q 个百分位数。
permute_dims(a, /, axes) 返回通过转置轴的数组。
piecewise(x, condlist, funclist, *args, **kw) 计算分段定义的函数。
place(arr, mask, vals, *[, inplace]) 根据条件和输入值改变数组的元素。
poly(seq_of_zeros) 根据给定的根序列找到多项式的系数。
polyadd(a1, a2) 计算两个多项式的和。
polyder(p[, m]) 返回多项式指定阶数的导数。
polydiv(u, v, *[, trim_leading_zeros]) 返回多项式除法的商和余数。
polyfit(x, y, deg[, rcond, full, w, cov]) 最小二乘多项式拟合。
polyint(p[, m, k]) 返回多项式的不定积分(反导数)。
polymul(a1, a2, *[, trim_leading_zeros]) 计算两个多项式的乘积。
polysub(a1, a2) 两个多项式的差(减法)。
polyval(p, x, *[, unroll]) 在特定值处计算多项式的值。
positive(x, /) 数值的正值,逐元素操作。
pow(x1, x2, /) 将第一个数组元素按第二个数组元素的幂进行元素级操作。
power(x1, x2, /) 将第一个数组元素按第二个数组元素的幂进行元素级操作。
printoptions(*args, **kwargs) 设置打印选项的上下文管理器。
prod(a[, axis, dtype, out, keepdims, …]) 返回给定轴上数组元素的乘积。
promote_types(a, b) 返回二进制操作应将其参数转换为的类型。
ptp(a[, axis, out, keepdims]) 沿某个轴的值范围(最大值 - 最小值)。
put(a, ind, v[, mode, inplace]) 用给定值替换数组的指定元素。
quantile(a, q[, axis, out, overwrite_input, …]) 计算沿指定轴的数据的第 q 个分位数。
r_ 沿第一个轴连接切片、标量和类数组对象。
rad2deg(x, /) 将角度从弧度转换为度。
radians(x, /) 将角度从度转换为弧度。
ravel(a[, order]) 将数组展平为一维形状。
ravel_multi_index(multi_index, dims[, mode, …]) 将多维索引转换为平坦索引。
real(val, /) 返回复数参数的实部。
reciprocal(x, /) 返回参数的倒数,逐元素操作。
remainder(x1, x2, /) 返回除法的元素级余数。
repeat(a, repeats[, axis, total_repeat_length]) 将数组中每个元素重复指定次数。
reshape(a[, shape, order, newshape]) 返回数组的重塑副本。
resize(a, new_shape) 返回具有指定形状的新数组。
result_type(*args) 返回应用于 NumPy 的结果类型。
right_shift(x1, x2, /) x1 的位向右移动到指定的 x2 量。
rint(x, /) 将数组元素四舍五入到最接近的整数。
roll(a, shift[, axis]) 沿指定轴滚动数组元素。
rollaxis(a, axis[, start]) 将指定的轴滚动到给定位置。
roots(p, *[, strip_zeros]) 返回具有给定系数的多项式的根。
rot90(m[, k, axes]) 在由轴指定的平面中将数组旋转 90 度。
round(a[, decimals, out]) 将数组四舍五入到指定的小数位数。
round_(a[, decimals, out]) 将数组四舍五入到指定的小数位数。
s_ 用于构建数组索引元组的更好方式。
save(file, arr[, allow_pickle, fix_imports]) 将数组以 NumPy .npy 格式保存到二进制文件中。
savez(file, *args, **kwds) 以未压缩的 .npz 格式将多个数组保存到单个文件中。
searchsorted(a, v[, side, sorter, method]) 在排序数组内执行二分搜索。
select(condlist, choicelist[, default]) 根据条件从 choicelist 中选择元素返回数组。
set_printoptions([precision, threshold, …]) 设置打印选项。
setdiff1d(ar1, ar2[, assume_unique, size, …]) 计算两个一维数组的差集。
setxor1d(ar1, ar2[, assume_unique]) 计算两个数组中元素的异或。
shape(a) 返回数组的形状。
sign(x, /) 返回数的元素级别符号指示。
signbit(x, /) 返回元素级别的 True,其中设置了符号位(小于零)。
signedinteger() 所有有符号整数标量类型的抽象基类。
sin(x, /) 按元素计算三角正弦。
sinc(x, /) 返回归一化的 sinc 函数。
single float32 的别名。
sinh(x, /) 按元素计算双曲正弦。
size(a[, axis]) 返回给定轴上的元素数量。
sort(a[, axis, kind, order, stable, descending]) 返回数组的排序副本。
sort_complex(a) 使用实部先排序复杂数组,然后按虚部排序。
split(ary, indices_or_sections[, axis]) 将数组拆分为多个子数组,作为 ary 的视图。
sqrt(x, /) 返回数组元素的非负平方根。
square(x, /) 返回输入数组的按元素平方。
squeeze(a[, axis]) 从数组中移除一个或多个长度为 1 的轴。
stack(arrays[, axis, out, dtype]) 沿新轴连接序列的数组。
std(a[, axis, dtype, out, ddof, keepdims, …]) 沿指定轴计算标准差。
subtract(x1, x2, /) 逐元素地进行减法运算。
sum(a[, axis, dtype, out, keepdims, …]) 沿给定轴对数组元素求和。
swapaxes(a, axis1, axis2) 交换数组的两个轴。
take(a, indices[, axis, out, mode, …]) 从数组中取出元素。
take_along_axis(arr, indices, axis[, mode, …]) 从数组中取出元素。
tan(x, /) 计算元素的正切。
tanh(x, /) 计算元素的双曲正切。
tensordot(a, b[, axes, precision, …]) 计算两个 N 维数组的张量点积。
tile(A, reps) 通过重复 A 指定的次数构造一个数组。
trace(a[, offset, axis1, axis2, dtype, out]) 返回数组的对角线之和。
trapezoid(y[, x, dx, axis]) 使用复合梯形规则沿指定轴积分。
transpose(a[, axes]) 返回 N 维数组的转置版本。
tri(N[, M, k, dtype]) 一个在给定对角线及其以下位置为 1,其他位置为 0 的数组。
tril(m[, k]) 数组的下三角形。
tril_indices(n[, k, m]) 返回(n, m)数组的下三角形的索引。
tril_indices_from(arr[, k]) 返回数组 arr 的下三角形的索引。
trim_zeros(filt[, trim]) 从一维数组或序列中修剪前导和/或尾随的零。
triu(m[, k]) 数组的上三角形。
triu_indices(n[, k, m]) 返回(n, m)数组的上三角形的索引。
triu_indices_from(arr[, k]) 返回数组 arr 的上三角形的索引。
true_divide(x1, x2, /) 逐元素地进行除法运算。
trunc(x) 返回输入元素的截断值。
ufunc(func, /, nin, nout, *[, name, nargs, …]) 在整个数组上逐元素操作的函数。
uint uint64的别名。
uint16(x)
uint32(x)
uint64(x)
uint8(x)
union1d(ar1, ar2, *[, size, fill_value]) 计算两个 1D 数组的并集。
unique(ar[, return_index, return_inverse, …]) 返回数组中的唯一值。
unique_all(x, /, *[, size, fill_value]) 返回 x 的唯一值以及索引、逆索引和计数。
unique_counts(x, /, *[, size, fill_value]) 返回 x 的唯一值及其计数。
unique_inverse(x, /, *[, size, fill_value]) 返回 x 的唯一值以及索引、逆索引和计数。
unique_values(x, /, *[, size, fill_value]) 返回 x 的唯一值以及索引、逆索引和计数。
unpackbits(a[, axis, count, bitorder]) 将 uint8 数组的元素解包为二进制值输出数组。
unravel_index(indices, shape) 将扁平索引转换为多维索引。
unstack(x, /, *[, axis])
unsignedinteger() 所有无符号整数标量类型的抽象基类。
unwrap(p[, discont, axis, period]) 通过取周期的补集来展开数组。
vander(x[, N, increasing]) 生成范德蒙矩阵。
var(a[, axis, dtype, out, ddof, keepdims, …]) 计算沿指定轴的方差。
vdot(a, b, *[, precision, …]) 执行两个 1D 向量的共轭乘法。
vecdot(x1, x2, /, *[, axis, precision, …]) 执行两个批量向量的共轭乘法。
vectorize(pyfunc, *[, excluded, signature]) 定义一个具有广播功能的向量化函数。
vsplit(ary, indices_or_sections) 按垂直(行)方向将数组分割成多个子数组。
vstack(tup[, dtype]) 沿垂直(行)方向堆叠数组序列。
where() 根据条件从两个数组中选择元素。
zeros(shape[, dtype, device]) 返回一个给定形状和类型的全零数组。
zeros_like(a[, dtype, shape, device]) 返回与给定数组相同形状和类型的全零数组。

jax.numpy.fft

fft(a[, n, axis, norm]) 计算一维离散傅里叶变换。
fft2(a[, s, axes, norm]) 计算二维离散傅里叶变换。
fftfreq(n[, d, dtype]) 返回离散傅里叶变换的样本频率。
fftn(a[, s, axes, norm]) 计算 N 维离散傅里叶变换。
fftshift(x[, axes]) 将零频率分量移动到频谱中心。
hfft(a[, n, axis, norm]) 计算具有 Hermitian 对称性的信号的 FFT。
ifft(a[, n, axis, norm]) 计算一维离散傅里叶逆变换。
ifft2(a[, s, axes, norm]) 计算二维离散傅里叶逆变换。
ifftn(a[, s, axes, norm]) 计算 N 维离散傅里叶逆变换。
ifftshift(x[, axes]) fftshift 的逆操作。
ihfft(a[, n, axis, norm]) 计算具有 Hermitian 对称性的信号的逆 FFT。
irfft(a[, n, axis, norm]) 计算 rfft 的逆变换。
irfft2(a[, s, axes, norm]) 计算 rfft2 的逆变换。
irfftn(a[, s, axes, norm]) 计算 rfftn 的逆变换。
rfft(a[, n, axis, norm]) 计算一维实数输入的离散傅里叶变换。
rfft2(a[, s, axes, norm]) 计算实数组的二维 FFT。
rfftfreq(n[, d, dtype]) 返回离散傅里叶变换的样本频率。

| rfftn(a[, s, axes, norm]) | 计算实数输入的 N 维离散傅里叶变换。 | ## jax.numpy.linalg

cholesky(a, *[, upper]) 计算矩阵的 Cholesky 分解。
cond(x[, p]) 计算矩阵的条件数。
cross(x1, x2, /, *[, axis]) 计算两个 3D 向量的叉乘。
det 计算数组的行列式。
diagonal(x, /, *[, offset]) 提取矩阵或矩阵堆栈的对角线元素。
eig(a) 计算方阵的特征值和特征向量。
eigh(a[, UPLO, symmetrize_input]) 计算 Hermitian 矩阵的特征值和特征向量。
eigvals(a) 计算一般矩阵的特征值。
eigvalsh(a[, UPLO]) 计算 Hermitian 矩阵的特征值。
inv(a) 返回方阵的逆。
lstsq(a, b[, rcond, numpy_resid]) 返回线性方程组的最小二乘解。
matmul(x1, x2, /, *[, precision, …]) 执行矩阵乘法。
matrix_norm(x, /, *[, keepdims, ord]) 计算矩阵或矩阵堆栈的范数。
matrix_power(a, n) 将方阵提升到整数幂。
matrix_rank(M[, rtol, tol]) 计算矩阵的秩。
matrix_transpose(x, /) 转置矩阵或矩阵堆栈。
multi_dot(arrays, *[, precision]) 高效计算数组序列之间的矩阵乘积。
norm(x[, ord, axis, keepdims]) 计算矩阵或向量的范数。
outer(x1, x2, /) 计算两个一维数组的外积。
pinv(a[, rtol, hermitian, rcond]) 计算(Moore-Penrose)伪逆。
qr() 计算数组的 QR 分解。
slogdet(a, *[, method]) 计算数组行列式的符号和(自然)对数。
solve(a, b) 解线性方程组。
svd() 计算奇异值分解。
svdvals(x, /) 计算矩阵的奇异值。
tensordot(x1, x2, /, *[, axes, precision, …]) 计算两个 N 维数组的张量点积。
tensorinv(a[, ind]) 计算数组的张量逆。
tensorsolve(a, b[, axes]) 解张量方程 a x = b 以得到 x。
trace(x, /, *[, offset, dtype]) 计算矩阵的迹。
vector_norm(x, /, *[, axis, keepdims, ord]) 计算向量或向量批次的范数。
vecdot(x1, x2, /, *[, axis, precision, …]) 计算(批量)向量共轭点积。

JAX Array

JAX Array(以及其别名 jax.numpy.ndarray)是 JAX 中的核心数组对象:您可以将其视为 JAX 中与numpy.ndarray 等效的对象。与 numpy.ndarray 一样,大多数用户不需要手动实例化 Array 对象,而是通过 jax.numpy 函数如 array()arange()linspace() 和上面列出的其他函数来创建它们。

复制和序列化

JAX Array对象设计为在适当的情况下与 Python 标准库工具无缝配合。

使用内置copy模块时,当copy.copy()copy.deepcopy()遇到Array时,等效于调用copy()方法,该方法将在与原始数组相同设备上创建缓冲区的副本。在追踪/JIT 编译的代码中,这将正确工作,尽管在此上下文中,复制操作可能会被编译器省略。

当内置pickle模块遇到Array时,它将通过紧凑的位表示方式对其进行序列化,类似于对numpy.ndarray对象的处理。解封后,结果将是一个新的Array对象在默认设备上。这是因为通常情况下,pickling 和 unpickling 可能发生在不同的运行环境中,并且没有通用的方法将一个运行时环境的设备 ID 映射到另一个的设备 ID。如果在追踪/JIT 编译的代码中使用pickle,将导致ConcretizationTypeError


JAX 中文文档(十三)(5)https://developer.aliyun.com/article/1559746

相关文章
|
3月前
|
算法 Serverless 索引
JAX 中文文档(十三)(5)
JAX 中文文档(十三)
23 1
|
3月前
|
缓存 TensorFlow 算法框架/工具
JAX 中文文档(十三)(3)
JAX 中文文档(十三)
58 1
|
3月前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(4)
JAX 中文文档(十五)
26 3
|
3月前
|
存储 API 索引
JAX 中文文档(十五)(5)
JAX 中文文档(十五)
34 3
|
3月前
|
机器学习/深度学习 数据可视化 编译器
JAX 中文文档(十四)(5)
JAX 中文文档(十四)
34 2
|
3月前
JAX 中文文档(十一)(5)
JAX 中文文档(十一)
14 1
|
3月前
|
API 异构计算 Python
JAX 中文文档(十一)(4)
JAX 中文文档(十一)
23 1
|
3月前
|
机器学习/深度学习 Shell API
JAX 中文文档(十三)(1)
JAX 中文文档(十三)
31 0
|
3月前
|
测试技术 API 调度
JAX 中文文档(十三)(2)
JAX 中文文档(十三)
24 0
|
3月前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(2)
JAX 中文文档(十五)
23 0