JAX 中文文档(十五)(4)https://developer.aliyun.com/article/1559771
稀疏 API 参考
sparsify (f[, use_tracer]) |
实验性稀疏化转换。 |
grad (fun[, argnums, has_aux]) |
jax.grad() 的稀疏版本 |
value_and_grad (fun[, argnums, has_aux]) |
jax.value_and_grad() 的稀疏版本 |
empty (shape[, dtype, index_dtype, sparse_format]) |
创建空稀疏数组。 |
eye (N[, M, k, dtype, index_dtype, sparse_format]) |
创建二维稀疏单位矩阵。 |
todense (arr) |
将输入转换为密集矩阵。 |
random_bcoo (key, shape, *[, dtype, …]) |
生成随机 BCOO 矩阵。 |
JAXSparse (args, *, shape) |
高级 JAX 稀疏对象的基类。 |
BCOO 数据结构
BCOO
是 Batched COO format,是在 jax.experimental.sparse
中实现的主要稀疏数据结构。其操作与 JAX 的核心转换兼容,包括批处理(例如 jax.vmap()
)和自动微分(例如 jax.grad()
)。
BCOO (args, *, shape[, indices_sorted, …]) |
在 JAX 中实现的实验性批量 COO 矩阵 |
bcoo_broadcast_in_dim (mat, *, shape, …) |
通过复制数据来扩展 BCOO 数组的大小和秩。 |
bcoo_concatenate (operands, *, dimension) |
jax.lax.concatenate() 的稀疏实现 |
bcoo_dot_general (lhs, rhs, *, dimension_numbers) |
一般的收缩操作。 |
bcoo_dot_general_sampled (A, B, indices, *, …) |
在给定稀疏索引处计算输出的收缩操作。 |
bcoo_dynamic_slice (mat, start_indices, …) |
jax.lax.dynamic_slice 的稀疏实现。 |
bcoo_extract (sparr, arr, *[, assume_unique]) |
根据稀疏数组的索引从密集数组中提取值。 |
bcoo_fromdense (mat, *[, nse, n_batch, …]) |
从密集矩阵创建 BCOO 格式的稀疏矩阵。 |
bcoo_gather (operand, start_indices, …[, …]) |
lax.gather 的 BCOO 版本。 |
bcoo_multiply_dense (sp_mat, v) |
稀疏数组和密集数组的逐元素乘法。 |
bcoo_multiply_sparse (lhs, rhs) |
两个稀疏数组的逐元素乘法。 |
bcoo_update_layout (mat, *[, n_batch, …]) |
更新 BCOO 矩阵的存储布局(即 n_batch 和 n_dense)。 |
bcoo_reduce_sum (mat, *, axes) |
在给定轴上对数组元素求和。 |
bcoo_reshape (mat, *, new_sizes[, dimensions]) |
{func}jax.lax.reshape 的稀疏实现。 |
bcoo_slice (mat, *, start_indices, limit_indices) |
{func}jax.lax.slice 的稀疏实现。 |
bcoo_sort_indices (mat) |
对 BCOO 数组的索引进行排序。 |
bcoo_squeeze (arr, *, dimensions) |
{func}jax.lax.squeeze 的稀疏实现。 |
bcoo_sum_duplicates (mat[, nse]) |
对 BCOO 数组中的重复索引求和,返回一个排序后的索引数组。 |
bcoo_todense (mat) |
将批量稀疏矩阵转换为密集矩阵。 |
bcoo_transpose (mat, *, permutation) |
转置 BCOO 格式的数组。 |
BCSR 数据结构
BCSR
是批量压缩稀疏行格式,正在开发中。其操作与 JAX 的核心转换兼容,包括批处理(如jax.vmap()
)和自动微分(如jax.grad()
)。
BCSR (args, *, shape[, indices_sorted, …]) |
在 JAX 中实现的实验性批量 CSR 矩阵。 |
bcsr_dot_general (lhs, rhs, *, dimension_numbers) |
通用收缩运算。 |
bcsr_extract (indices, indptr, mat) |
从给定的 BCSR(indices, indptr)处的密集矩阵中提取值。 |
bcsr_fromdense (mat, *[, nse, n_batch, …]) |
从密集矩阵创建 BCSR 格式的稀疏矩阵。 |
bcsr_todense (mat) |
将批量稀疏矩阵转换为密集矩阵。 |
其他稀疏数据结构
其他稀疏数据结构包括COO
、CSR
和CSC
。这些是简单稀疏结构的参考实现,具有少数核心操作。它们的操作通常与自动微分转换(如jax.grad()
)兼容,但不与批处理转换(如jax.vmap()
)兼容。
COO (args, *, shape[, rows_sorted, cols_sorted]) |
在 JAX 中实现的实验性 COO 矩阵。 |
CSC (args, *, shape) |
在 JAX 中实现的实验性 CSC 矩阵;API 可能会更改。 |
CSR (args, *, shape) |
在 JAX 中实现的实验性 CSR 矩阵。 |
coo_fromdense (mat, *[, nse, index_dtype]) |
从密集矩阵创建 COO 格式的稀疏矩阵。 |
coo_matmat (mat, B, *[, transpose]) |
COO 稀疏矩阵与密集矩阵的乘积。 |
coo_matvec (mat, v[, transpose]) |
COO 稀疏矩阵与密集向量的乘积。 |
coo_todense (mat) |
将 COO 格式的稀疏矩阵转换为密集矩阵。 |
csr_fromdense (mat, *[, nse, index_dtype]) |
从密集矩阵创建 CSR 格式的稀疏矩阵。 |
csr_matmat (mat, B, *[, transpose]) |
CSR 稀疏矩阵与密集矩阵的乘积。 |
csr_matvec (mat, v[, transpose]) |
CSR 稀疏矩阵与密集向量的乘积。 |
csr_todense (mat) |
将 CSR 格式的稀疏矩阵转换为密集矩阵。 |
jax.experimental.sparse.linalg
稀疏线性代数例程。
spsolve (data, indices, indptr, b[, tol, reorder]) |
使用 QR 分解的稀疏直接求解器。 |
lobpcg_standard (A, X[, m, tol]) |
使用 LOBPCG 例程计算前 k 个标准特征值。 |
jax.experimental.sparse.BCOO
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.BCOO.html
class jax.experimental.sparse.BCOO(args, *, shape, indices_sorted=False, unique_indices=False)
在 JAX 中实现的实验性批量 COO 矩阵
参数:
- **(**data – 批量 COO 格式中的数据和索引。
- indices**)** – 批量 COO 格式中的数据和索引。
- shape (tuple[int, …**]) – 稀疏数组的形状。
- args (tuple[Array,* Array]*)
- indices_sorted (bool)
- unique_indices (bool)
data
形状为[*batch_dims, nse, *dense_dims]
的 ndarray,包含稀疏矩阵中显式存储的数据。
类型:
jax.Array
indices
形状为[*batch_dims, nse, n_sparse]
的 ndarray,包含显式存储数据的索引。重复的条目将被求和。
类型:
jax.Array
示例
从稠密数组创建稀疏数组:
>>> M = jnp.array([[0., 2., 0.], [1., 0., 4.]]) >>> M_sp = BCOO.fromdense(M) >>> M_sp BCOO(float32[2, 3], nse=3)
检查内部表示:
>>> M_sp.data Array([2., 1., 4.], dtype=float32) >>> M_sp.indices Array([[0, 1], [1, 0], [1, 2]], dtype=int32)
从稀疏数组创建稠密数组:
>>> M_sp.todense() Array([[0., 2., 0.], [1., 0., 4.]], dtype=float32)
从 COO 数据和索引创建稀疏数组:
>>> data = jnp.array([1., 3., 5.]) >>> indices = jnp.array([[0, 0], ... [1, 1], ... [2, 2]]) >>> mat = BCOO((data, indices), shape=(3, 3)) >>> mat BCOO(float32[3, 3], nse=3) >>> mat.todense() Array([[1., 0., 0.], [0., 3., 0.], [0., 0., 5.]], dtype=float32)
__init__(args, *, shape, indices_sorted=False, unique_indices=False)
参数:
方法
__init__ (args, *, shape[, indices_sorted, …]) |
|
astype (*args, **kwargs) |
复制数组并转换为指定的 dtype。 |
block_until_ready () |
|
from_scipy_sparse (mat, *[, index_dtype, …]) |
从scipy.sparse 数组创建 BCOO 数组。 |
fromdense (mat, *[, nse, index_dtype, …]) |
从(稠密)Array 创建 BCOO 数组。 |
reshape (*args, **kwargs) |
返回具有新形状的相同数据的数组。 |
sort_indices () |
返回索引排序后的矩阵副本。 |
sum (*args, **kwargs) |
沿轴求和数组。 |
sum_duplicates ([nse, remove_zeros]) |
返回重复索引求和后的数组副本。 |
todense () |
创建数组的稠密版本。 |
transpose ([axes]) |
创建包含转置的新数组。 |
tree_flatten () |
|
tree_unflatten (aux_data, children) |
|
update_layout (*[, n_batch, n_dense, …]) |
更新 BCOO 矩阵的存储布局(即 n_batch 和 n_dense)。 |
属性
T |
|
dtype |
|
n_batch |
|
n_dense |
|
n_sparse |
|
ndim |
|
nse |
|
size |
|
data |
|
indices |
|
shape |
|
indices_sorted |
|
unique_indices |
jax.experimental.sparse.bcoo_broadcast_in_dim
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_broadcast_in_dim.html
jax.experimental.sparse.bcoo_broadcast_in_dim(mat, *, shape, broadcast_dimensions)
通过复制数据扩展 BCOO 数组的大小和秩。
BCOO 相当于 jax.lax.broadcast_in_dim。
参数:
- mat(BCOO) – BCOO 格式的数组。
- shape(tuple[int,* …]*) – 目标数组的形状。
- broadcast_dimensions(Sequence[int]) – 目标数组形状的维度,每个操作数(
mat
)形状对应一个维度。
返回:
包含目标数组的 BCOO 格式数组。
返回类型:
BCOO
jax.experimental.sparse.bcoo_concatenate
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_concatenate.html
jax.experimental.sparse.bcoo_concatenate(operands, *, dimension)
稀疏实现的jax.lax.concatenate()
函数
参数:
- operands(Sequence[BCOO]) – 要连接的 BCOO 数组序列。这些数组必须具有相同的形状,除了在维度轴上。此外,这些数组必须具有等效的批处理、稀疏和密集维度。
- dimension(int) – 指定沿其连接数组的维度的正整数。维度必须是输入的批处理或稀疏维度之一;不支持沿密集维度的连接。
返回值:
包含输入数组连接的 BCOO 数组。
返回类型:
BCOO
jax.experimental.sparse.bcoo_dot_general
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_dot_general.html
jax.experimental.sparse.bcoo_dot_general(lhs, rhs, *, dimension_numbers, precision=None, preferred_element_type=None)
一般的收缩操作。
参数:
- lhs(BCOO | Array) – 一个 ndarray 或 BCOO 格式的稀疏数组。
- rhs(BCOO | Array) – 一个 ndarray 或 BCOO 格式的稀疏数组。
- dimension_numbers(tuple[tuple[Sequence[int]**, Sequence[int]], tuple[Sequence[int]**, Sequence[int]]]) – 一个形如((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))的元组的元组。
- precision(None) – 未使用
- preferred_element_type(None) – 未使用
返回:
一个包含结果的 ndarray 或 BCOO 格式的稀疏数组。如果两个输入都是稀疏的,结果将是稀疏的,类型为 BCOO。如果任一输入是密集的,结果将是密集的,类型为 ndarray。
返回类型:
BCOO | Array
jax.experimental.sparse.bcoo_dot_general_sampled
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_dot_general_sampled.html
jax.experimental.sparse.bcoo_dot_general_sampled(A, B, indices, *, dimension_numbers)
给定稀疏索引处计算输出的收缩操作。
参数:
- lhs – 一个 ndarray。
- rhs – 一个 ndarray。
- indices(Array) – BCOO 索引。
- dimension_numbers(tuple[tuple[Sequence[int]**, Sequence[int]], tuple[Sequence[int]**, Sequence[int]]]) – 形式为 ((lhs 收缩维度,rhs 收缩维度),(lhs 批次维度,rhs 批次维度)) 的元组的元组。
- A(Array)
- B(Array)
返回:
BCOO 数据,包含结果的 ndarray。
返回类型:
Array
jax.experimental.sparse.bcoo_dynamic_slice
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_dynamic_slice.html
jax.experimental.sparse.bcoo_dynamic_slice(mat, start_indices, slice_sizes)
{func}jax.lax.dynamic_slice
的稀疏实现。
参数:
- mat (BCOO) – 要切片的 BCOO 数组。
- start_indices (Sequence[Any]) – 每个维度的标量索引列表。这些值可能是动态的。
- slice_sizes (Sequence[int]) – 切片的大小。必须是非负整数序列,长度等于操作数的维度数。在 JIT 编译的函数内部,仅支持静态值(所有 JAX 数组在 JIT 内必须具有静态已知大小)。
返回:
包含切片的 BCOO 数组。
返回类型:
out
jax.experimental.sparse.bcoo_extract
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_extract.html
jax.experimental.sparse.bcoo_extract(sparr, arr, *, assume_unique=None)
根据稀疏数组的索引从密集数组中提取值。
参数:
- sparr (BCOO) – 用于输出的 BCOO 数组的索引。
- arr (jax.typing.ArrayLike) – 形状与 self.shape 相同的 ArrayLike
- assume_unique (bool | None) – 布尔值,默认为 sparr.unique_indices。如果为 True,则提取每个索引的值,即使索引包含重复项。如果为 False,则重复的索引将其值求和,并返回第一个索引的位置。
返回:
一个具有与 self 相同稀疏模式的 BCOO 数组。
返回类型:
提取的结果
jax.experimental.sparse.bcoo_fromdense
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_fromdense.html
jax.experimental.sparse.bcoo_fromdense(mat, *, nse=None, n_batch=0, n_dense=0, index_dtype=<class 'jax.numpy.int32'>)
从密集矩阵创建 BCOO 格式的稀疏矩阵。
参数:
- mat(Array)– 要转换为 BCOO 格式的数组。
- nse(int | None)– 每个批次中指定元素的数量
- n_batch(int)– 批次维度的数量(默认:0)
- n_dense(int)– 块维度的数量(默认:0)
- index_dtype(jax.typing.DTypeLike)– 稀疏索引的数据类型(默认:int32)
返回:
矩阵的 BCOO 表示。
返回类型:
mat_bcoo
jax.experimental.sparse.bcoo_gather
原文:
jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.bcoo_gather.html
jax.experimental.sparse.bcoo_gather(operand, start_indices, dimension_numbers, slice_sizes, *, unique_indices=False, indices_are_sorted=False, mode=None, fill_value=None)
BCOO 版本的 lax.gather。
参数:
- operand (BCOO)
- start_indices (数组)
- dimension_numbers (GatherDimensionNumbers)
- slice_sizes (tuple[int, …**])
- unique_indices (bool)
- indices_are_sorted (bool)
- mode (str | GatherScatterMode | None)
返回类型:
BCOO