JAX 中文文档(十五)(5)

简介: JAX 中文文档(十五)

JAX 中文文档(十五)(4)https://developer.aliyun.com/article/1559771

稀疏 API 参考

sparsify(f[, use_tracer]) 实验性稀疏化转换。
grad(fun[, argnums, has_aux]) jax.grad() 的稀疏版本
value_and_grad(fun[, argnums, has_aux]) jax.value_and_grad() 的稀疏版本
empty(shape[, dtype, index_dtype, sparse_format]) 创建空稀疏数组。
eye(N[, M, k, dtype, index_dtype, sparse_format]) 创建二维稀疏单位矩阵。
todense(arr) 将输入转换为密集矩阵。
random_bcoo(key, shape, *[, dtype, …]) 生成随机 BCOO 矩阵。
JAXSparse(args, *, shape) 高级 JAX 稀疏对象的基类。

BCOO 数据结构

BCOOBatched COO format,是在 jax.experimental.sparse 中实现的主要稀疏数据结构。其操作与 JAX 的核心转换兼容,包括批处理(例如 jax.vmap())和自动微分(例如 jax.grad())。

BCOO(args, *, shape[, indices_sorted, …]) 在 JAX 中实现的实验性批量 COO 矩阵
bcoo_broadcast_in_dim(mat, *, shape, …) 通过复制数据来扩展 BCOO 数组的大小和秩。
bcoo_concatenate(operands, *, dimension) jax.lax.concatenate() 的稀疏实现
bcoo_dot_general(lhs, rhs, *, dimension_numbers) 一般的收缩操作。
bcoo_dot_general_sampled(A, B, indices, *, …) 在给定稀疏索引处计算输出的收缩操作。
bcoo_dynamic_slice(mat, start_indices, …) jax.lax.dynamic_slice 的稀疏实现。
bcoo_extract(sparr, arr, *[, assume_unique]) 根据稀疏数组的索引从密集数组中提取值。
bcoo_fromdense(mat, *[, nse, n_batch, …]) 从密集矩阵创建 BCOO 格式的稀疏矩阵。
bcoo_gather(operand, start_indices, …[, …]) lax.gather 的 BCOO 版本。
bcoo_multiply_dense(sp_mat, v) 稀疏数组和密集数组的逐元素乘法。
bcoo_multiply_sparse(lhs, rhs) 两个稀疏数组的逐元素乘法。
bcoo_update_layout(mat, *[, n_batch, …]) 更新 BCOO 矩阵的存储布局(即 n_batch 和 n_dense)。
bcoo_reduce_sum(mat, *, axes) 在给定轴上对数组元素求和。
bcoo_reshape(mat, *, new_sizes[, dimensions]) {func}jax.lax.reshape的稀疏实现。
bcoo_slice(mat, *, start_indices, limit_indices) {func}jax.lax.slice的稀疏实现。
bcoo_sort_indices(mat) 对 BCOO 数组的索引进行排序。
bcoo_squeeze(arr, *, dimensions) {func}jax.lax.squeeze的稀疏实现。
bcoo_sum_duplicates(mat[, nse]) 对 BCOO 数组中的重复索引求和,返回一个排序后的索引数组。
bcoo_todense(mat) 将批量稀疏矩阵转换为密集矩阵。
bcoo_transpose(mat, *, permutation) 转置 BCOO 格式的数组。

BCSR 数据结构

BCSR批量压缩稀疏行格式,正在开发中。其操作与 JAX 的核心转换兼容,包括批处理(如jax.vmap())和自动微分(如jax.grad())。

BCSR(args, *, shape[, indices_sorted, …]) 在 JAX 中实现的实验性批量 CSR 矩阵。
bcsr_dot_general(lhs, rhs, *, dimension_numbers) 通用收缩运算。
bcsr_extract(indices, indptr, mat) 从给定的 BCSR(indices, indptr)处的密集矩阵中提取值。
bcsr_fromdense(mat, *[, nse, n_batch, …]) 从密集矩阵创建 BCSR 格式的稀疏矩阵。
bcsr_todense(mat) 将批量稀疏矩阵转换为密集矩阵。

其他稀疏数据结构

其他稀疏数据结构包括COOCSRCSC。这些是简单稀疏结构的参考实现,具有少数核心操作。它们的操作通常与自动微分转换(如jax.grad())兼容,但不与批处理转换(如jax.vmap())兼容。

COO(args, *, shape[, rows_sorted, cols_sorted]) 在 JAX 中实现的实验性 COO 矩阵。
CSC(args, *, shape) 在 JAX 中实现的实验性 CSC 矩阵;API 可能会更改。
CSR(args, *, shape) 在 JAX 中实现的实验性 CSR 矩阵。
coo_fromdense(mat, *[, nse, index_dtype]) 从密集矩阵创建 COO 格式的稀疏矩阵。
coo_matmat(mat, B, *[, transpose]) COO 稀疏矩阵与密集矩阵的乘积。
coo_matvec(mat, v[, transpose]) COO 稀疏矩阵与密集向量的乘积。
coo_todense(mat) 将 COO 格式的稀疏矩阵转换为密集矩阵。
csr_fromdense(mat, *[, nse, index_dtype]) 从密集矩阵创建 CSR 格式的稀疏矩阵。
csr_matmat(mat, B, *[, transpose]) CSR 稀疏矩阵与密集矩阵的乘积。
csr_matvec(mat, v[, transpose]) CSR 稀疏矩阵与密集向量的乘积。
csr_todense(mat) 将 CSR 格式的稀疏矩阵转换为密集矩阵。

jax.experimental.sparse.linalg

稀疏线性代数例程。

spsolve(data, indices, indptr, b[, tol, reorder]) 使用 QR 分解的稀疏直接求解器。
lobpcg_standard(A, X[, m, tol]) 使用 LOBPCG 例程计算前 k 个标准特征值。

jax.experimental.sparse.BCOO

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.BCOO.html

class jax.experimental.sparse.BCOO(args, *, shape, indices_sorted=False, unique_indices=False)

在 JAX 中实现的实验性批量 COO 矩阵

参数:

  • **(**data – 批量 COO 格式中的数据和索引。
  • indices**)** – 批量 COO 格式中的数据和索引。
  • shape (tuple[int, …**]) – 稀疏数组的形状。
  • args (tuple[Array,* Array]*)
  • indices_sorted (bool)
  • unique_indices (bool)
data

形状为[*batch_dims, nse, *dense_dims]的 ndarray,包含稀疏矩阵中显式存储的数据。

类型:

jax.Array

indices

形状为[*batch_dims, nse, n_sparse]的 ndarray,包含显式存储数据的索引。重复的条目将被求和。

类型:

jax.Array

示例

从稠密数组创建稀疏数组:

>>> M = jnp.array([[0., 2., 0.], [1., 0., 4.]])
>>> M_sp = BCOO.fromdense(M)
>>> M_sp
BCOO(float32[2, 3], nse=3) 

检查内部表示:

>>> M_sp.data
Array([2., 1., 4.], dtype=float32)
>>> M_sp.indices
Array([[0, 1],
 [1, 0],
 [1, 2]], dtype=int32) 

从稀疏数组创建稠密数组:

>>> M_sp.todense()
Array([[0., 2., 0.],
 [1., 0., 4.]], dtype=float32) 

从 COO 数据和索引创建稀疏数组:

>>> data = jnp.array([1., 3., 5.])
>>> indices = jnp.array([[0, 0],
...                      [1, 1],
...                      [2, 2]])
>>> mat = BCOO((data, indices), shape=(3, 3))
>>> mat
BCOO(float32[3, 3], nse=3)
>>> mat.todense()
Array([[1., 0., 0.],
 [0., 3., 0.],
 [0., 0., 5.]], dtype=float32) 
__init__(args, *, shape, indices_sorted=False, unique_indices=False)

参数:

方法

__init__(args, *, shape[, indices_sorted, …])
astype(*args, **kwargs) 复制数组并转换为指定的 dtype。
block_until_ready()
from_scipy_sparse(mat, *[, index_dtype, …]) scipy.sparse数组创建 BCOO 数组。
fromdense(mat, *[, nse, index_dtype, …]) 从(稠密)Array创建 BCOO 数组。
reshape(*args, **kwargs) 返回具有新形状的相同数据的数组。
sort_indices() 返回索引排序后的矩阵副本。
sum(*args, **kwargs) 沿轴求和数组。
sum_duplicates([nse, remove_zeros]) 返回重复索引求和后的数组副本。
todense() 创建数组的稠密版本。
transpose([axes]) 创建包含转置的新数组。
tree_flatten()
tree_unflatten(aux_data, children)
update_layout(*[, n_batch, n_dense, …]) 更新 BCOO 矩阵的存储布局(即 n_batch 和 n_dense)。

属性

T
dtype
n_batch
n_dense
n_sparse
ndim
nse
size
data
indices
shape
indices_sorted
unique_indices

jax.experimental.sparse.bcoo_broadcast_in_dim

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_broadcast_in_dim.html

jax.experimental.sparse.bcoo_broadcast_in_dim(mat, *, shape, broadcast_dimensions)

通过复制数据扩展 BCOO 数组的大小和秩。

BCOO 相当于 jax.lax.broadcast_in_dim。

参数:

  • matBCOO) – BCOO 格式的数组。
  • shapetuple[int,* ]*) – 目标数组的形状。
  • broadcast_dimensionsSequence[int]) – 目标数组形状的维度,每个操作数(mat)形状对应一个维度。

返回:

包含目标数组的 BCOO 格式数组。

返回类型:

BCOO

jax.experimental.sparse.bcoo_concatenate

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_concatenate.html

jax.experimental.sparse.bcoo_concatenate(operands, *, dimension)

稀疏实现的jax.lax.concatenate()函数

参数:

  • operandsSequence[BCOO]) – 要连接的 BCOO 数组序列。这些数组必须具有相同的形状,除了在维度轴上。此外,这些数组必须具有等效的批处理、稀疏和密集维度。
  • dimensionint) – 指定沿其连接数组的维度的正整数。维度必须是输入的批处理或稀疏维度之一;不支持沿密集维度的连接。

返回值:

包含输入数组连接的 BCOO 数组。

返回类型:

BCOO

jax.experimental.sparse.bcoo_dot_general

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_dot_general.html

jax.experimental.sparse.bcoo_dot_general(lhs, rhs, *, dimension_numbers, precision=None, preferred_element_type=None)

一般的收缩操作。

参数:

  • lhsBCOO | Array) – 一个 ndarray 或 BCOO 格式的稀疏数组。
  • rhsBCOO | Array) – 一个 ndarray 或 BCOO 格式的稀疏数组。
  • dimension_numberstuple[tuple[Sequence[int]**, Sequence[int]], tuple[Sequence[int]**, Sequence[int]]]) – 一个形如((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))的元组的元组。
  • precisionNone) – 未使用
  • preferred_element_typeNone) – 未使用

返回:

一个包含结果的 ndarray 或 BCOO 格式的稀疏数组。如果两个输入都是稀疏的,结果将是稀疏的,类型为 BCOO。如果任一输入是密集的,结果将是密集的,类型为 ndarray。

返回类型:

BCOO | Array

jax.experimental.sparse.bcoo_dot_general_sampled

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_dot_general_sampled.html

jax.experimental.sparse.bcoo_dot_general_sampled(A, B, indices, *, dimension_numbers)

给定稀疏索引处计算输出的收缩操作。

参数:

  • lhs – 一个 ndarray。
  • rhs – 一个 ndarray。
  • indicesArray) – BCOO 索引。
  • dimension_numberstuple[tuple[Sequence[int]**, Sequence[int]], tuple[Sequence[int]**, Sequence[int]]]) – 形式为 ((lhs 收缩维度,rhs 收缩维度),(lhs 批次维度,rhs 批次维度)) 的元组的元组。
  • AArray
  • BArray

返回:

BCOO 数据,包含结果的 ndarray。

返回类型:

Array

jax.experimental.sparse.bcoo_dynamic_slice

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_dynamic_slice.html

jax.experimental.sparse.bcoo_dynamic_slice(mat, start_indices, slice_sizes)

{func}jax.lax.dynamic_slice的稀疏实现。

参数:

  • mat (BCOO) – 要切片的 BCOO 数组。
  • start_indices (Sequence[Any]) – 每个维度的标量索引列表。这些值可能是动态的。
  • slice_sizes (Sequence[int]) – 切片的大小。必须是非负整数序列,长度等于操作数的维度数。在 JIT 编译的函数内部,仅支持静态值(所有 JAX 数组在 JIT 内必须具有静态已知大小)。

返回:

包含切片的 BCOO 数组。

返回类型:

out

jax.experimental.sparse.bcoo_extract

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_extract.html

jax.experimental.sparse.bcoo_extract(sparr, arr, *, assume_unique=None)

根据稀疏数组的索引从密集数组中提取值。

参数:

  • sparr (BCOO) – 用于输出的 BCOO 数组的索引。
  • arr (jax.typing.ArrayLike) – 形状与 self.shape 相同的 ArrayLike
  • assume_unique (bool | None) – 布尔值,默认为 sparr.unique_indices。如果为 True,则提取每个索引的值,即使索引包含重复项。如果为 False,则重复的索引将其值求和,并返回第一个索引的位置。

返回:

一个具有与 self 相同稀疏模式的 BCOO 数组。

返回类型:

提取的结果

jax.experimental.sparse.bcoo_fromdense

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_fromdense.html

jax.experimental.sparse.bcoo_fromdense(mat, *, nse=None, n_batch=0, n_dense=0, index_dtype=<class 'jax.numpy.int32'>)

从密集矩阵创建 BCOO 格式的稀疏矩阵。

参数:

  • matArray)– 要转换为 BCOO 格式的数组。
  • nseint | None)– 每个批次中指定元素的数量
  • n_batchint)– 批次维度的数量(默认:0)
  • n_denseint)– 块维度的数量(默认:0)
  • index_dtypejax.typing.DTypeLike)– 稀疏索引的数据类型(默认:int32)

返回:

矩阵的 BCOO 表示。

返回类型:

mat_bcoo

jax.experimental.sparse.bcoo_gather

原文:jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_gather.html

jax.experimental.sparse.bcoo_gather(operand, start_indices, dimension_numbers, slice_sizes, *, unique_indices=False, indices_are_sorted=False, mode=None, fill_value=None)

BCOO 版本的 lax.gather。

参数:

  • operand (BCOO)
  • start_indices (数组)
  • dimension_numbers (GatherDimensionNumbers)
  • slice_sizes (tuple[int, …**])
  • unique_indices (bool)
  • indices_are_sorted (bool)
  • mode (str | GatherScatterMode | None)

返回类型:

BCOO

相关文章
|
4月前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(4)
JAX 中文文档(十五)
40 3
|
4月前
|
机器学习/深度学习 数据可视化 编译器
JAX 中文文档(十四)(5)
JAX 中文文档(十四)
51 2
|
4月前
JAX 中文文档(十一)(5)
JAX 中文文档(十一)
17 1
|
4月前
|
API 异构计算 Python
JAX 中文文档(十一)(4)
JAX 中文文档(十一)
33 1
|
4月前
|
算法 API 开发工具
JAX 中文文档(十二)(5)
JAX 中文文档(十二)
55 1
|
4月前
|
安全 API 网络架构
JAX 中文文档(十五)(1)
JAX 中文文档(十五)
51 0
|
4月前
|
TensorFlow API 算法框架/工具
JAX 中文文档(十五)(3)
JAX 中文文档(十五)
29 0
|
4月前
|
机器学习/深度学习 存储 API
JAX 中文文档(十五)(2)
JAX 中文文档(十五)
36 0
|
4月前
|
Python
JAX 中文文档(十四)(4)
JAX 中文文档(十四)
32 0
|
4月前
|
关系型数据库
JAX 中文文档(十四)(1)
JAX 中文文档(十四)
33 0