jax.experimental.sparse
模块#
注意
jax.experimental.sparse
中的方法是实验性的参考实现,不建议用于性能关键型应用。
模块支持 JAX 中稀疏矩阵操作的实验性功能。该模块正在积极开发中,API 可能会发生变化。主要提供的接口是 jax.experimental.sparse
BCOO
稀疏数组类型,以及 sparsify()
变换。
批处理坐标 (BCOO) 稀疏矩阵#
JAX 中当前可用主要的 L2 稀疏对象是 BCOO
,即批处理坐标稀疏数组。它提供了一种与 JAX 变换兼容的压缩存储格式,特别是 JIT(例如 jax.jit()
)、批处理(例如 jax.vmap()
)和自动微分(例如 jax.grad()
)。
下面是如何从密集数组创建稀疏数组的示例
>>> from jax.experimental import sparse
>>> import jax.numpy as jnp
>>> import numpy as np
>>> M = jnp.array([[0., 1., 0., 2.],
... [3., 0., 0., 0.],
... [0., 0., 4., 0.]])
>>> M_sp = sparse.BCOO.fromdense(M)
>>> M_sp
BCOO(float32[3, 4], nse=4)
使用 `todense()` 方法转换回密集数组
>>> M_sp.todense()
Array([[0., 1., 0., 2.],
[3., 0., 0., 0.],
[0., 0., 4., 0.]], dtype=float32)
BCOO 格式是标准 COO 格式的一种稍作修改的版本,其密集表示可以在 `data` 和 `indices` 属性中看到
>>> M_sp.data # Explicitly stored data
Array([1., 2., 3., 4.], dtype=float32)
>>> M_sp.indices # Indices of the stored data
Array([[0, 1],
[0, 3],
[1, 0],
[2, 2]], dtype=int32)
BCOO 对象具有熟悉的类数组属性,以及特定于稀疏的属性
>>> M_sp.ndim
2
>>> M_sp.shape
(3, 4)
>>> M_sp.dtype
dtype('float32')
>>> M_sp.nse # "number of specified elements"
4
BCOO 对象还实现了许多类数组方法,以便您可以直接在 jax 程序中使用它们。例如,这里我们计算转置矩阵-向量乘积
>>> y = jnp.array([3., 6., 5.])
>>> M_sp.T @ y
Array([18., 3., 20., 6.], dtype=float32)
>>> M.T @ y # Compare to dense version
Array([18., 3., 20., 6.], dtype=float32)
BCOO 对象旨在与 JAX 变换兼容,包括 jax.jit()
、jax.vmap()
、jax.grad()
等。例如
>>> from jax import grad, jit
>>> def f(y):
... return (M_sp.T @ y).sum()
...
>>> jit(grad(f))(y)
Array([3., 3., 4.], dtype=float32)
然而,请注意,在正常情况下,jax.numpy
和 jax.lax
函数无法处理稀疏矩阵,因此尝试计算诸如 `jnp.dot(M_sp.T, y)` 之类的操作将导致错误(然而,请参阅下一节)。
Sparsify 变换#
JAX 稀疏实现的一个总体目标是提供一种无缝切换密集和稀疏计算的方法,而无需修改密集实现。这种稀疏实验通过 `sparsify()` 变换来实现。
考虑此函数,它从矩阵和向量输入计算更复杂的结果
>>> def f(M, v):
... return 2 * jnp.dot(jnp.log1p(M.T), v) + 1
...
>>> f(M, y)
Array([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
如果我们将稀疏矩阵直接传递给它,将会导致错误,因为 `jnp` 函数不识别稀疏输入。但是,通过 `sparsify()`,我们可以获得一个接受稀疏矩阵的该函数的版本
>>> f_sp = sparse.sparsify(f)
>>> f_sp(M_sp, y)
Array([17.635532, 5.158883, 17.09438 , 7.591674], dtype=float32)
对 `sparsify()` 的支持包括许多最常见的原始操作,包括
广义(批处理)矩阵乘积和爱因斯坦求和(`dot_general_p`)
零保持的逐元素二元运算(例如 `add_p`、`mul_p` 等)
零保持的逐元素一元运算(例如 `abs_p`、`jax.lax.neg_p` 等)
求和规约(`reduce_sum_p`)
通用索引操作(`slice_p`、`lax.dynamic_slice_p`、`lax.gather_p`)
连接和堆叠(`concatenate_p`)
转置和重塑(`transpose_p`、`reshape_p`、`squeeze_p`、`broadcast_in_dim_p`)
一些高阶函数(`cond_p`、`while_p`、`scan_p`)
一些简单的 1D 卷积(`conv_general_dilated_p`)
几乎任何可以降低到这些支持的原始操作的 `jax.numpy` 函数都可以在 `sparsify` 变换中使用,以对稀疏数组进行操作。这些原始操作集足以实现相对复杂的稀疏工作流,如下一节所示。
示例:稀疏逻辑回归#
作为一个更复杂的稀疏工作流示例,让我们考虑一个用 JAX 实现的简单逻辑回归。请注意,以下实现没有提及稀疏性
>>> import functools
>>> from sklearn.datasets import make_classification
>>> from jax.scipy import optimize
>>> def sigmoid(x):
... return 0.5 * (jnp.tanh(x / 2) + 1)
...
>>> def y_model(params, X):
... return sigmoid(jnp.dot(X, params[1:]) + params[0])
...
>>> def loss(params, X, y):
... y_hat = y_model(params, X)
... return -jnp.mean(y * jnp.log(y_hat) + (1 - y) * jnp.log(1 - y_hat))
...
>>> def fit_logreg(X, y):
... params = jnp.zeros(X.shape[1] + 1)
... result = optimize.minimize(functools.partial(loss, X=X, y=y),
... x0=params, method='BFGS')
... return result.x
>>> X, y = make_classification(n_classes=2, random_state=1701)
>>> params_dense = fit_logreg(X, y)
>>> print(params_dense)
[-0.7298445 0.29893667 1.0248291 -0.44436368 0.8785025 -0.7724008
-0.62893456 0.2934014 0.82974285 0.16838408 -0.39774987 -0.5071844
0.2028872 0.5227761 -0.3739224 -0.7104083 2.4212713 0.6310087
-0.67060554 0.03139788 -0.05359547]
这返回密集逻辑回归问题的最佳拟合参数。为了在稀疏数据上拟合相同的模型,我们可以应用 `sparsify()` 变换
>>> Xsp = sparse.BCOO.fromdense(X) # Sparse version of the input
>>> fit_logreg_sp = sparse.sparsify(fit_logreg) # Sparse-transformed fit function
>>> params_sparse = fit_logreg_sp(Xsp, y)
>>> print(params_sparse)
[-0.72971725 0.29878938 1.0246326 -0.44430563 0.8784217 -0.77225566
-0.6288222 0.29335397 0.8293481 0.16820715 -0.39764675 -0.5069753
0.202579 0.522672 -0.3740134 -0.7102678 2.4209507 0.6310593
-0.670236 0.03132951 -0.05356663]
稀疏 API 参考#
|
实验性稀疏化变换。 |
|
支持稀疏的 `jax.grad()` 版本 |
|
支持稀疏的 `jax.value_and_grad()` 版本 |
|
创建一个空的稀疏数组。 |
|
创建 2D 稀疏单位矩阵。 |
|
将输入转换为密集矩阵。 |
|
生成一个随机 BCOO 矩阵。 |
|
L2 JAX 稀疏对象的基类。 |
BCOO 数据结构#
BCOO
是批处理 COO 格式,是 `jax.experimental.sparse` 中实现的主要稀疏数据结构。它的操作与 JAX 的核心变换兼容,包括批处理(例如 `jax.vmap()`)和自动微分(例如 `jax.grad()`)。
|
JAX 中实现的实验性批处理 COO 矩阵 |
|
通过复制数据来扩展 BCOO 数组的大小和秩。 |
|
Jax.lax.concatenate() 的稀疏实现 |
|
一个通用的收缩操作。 |
|
在给定稀疏索引处计算输出的收缩操作。 |
|
Jax.lax.dynamic_slice() 的稀疏实现。 |
|
根据稀疏数组的索引从密集数组中提取值。 |
|
从密集矩阵创建 BCOO 格式的稀疏矩阵。 |
|
BCOO 版 lax.gather。 |
|
稀疏数组和密集数组之间的逐元素乘法。 |
|
两个稀疏数组的逐元素乘法。 |
|
更新 BCOO 矩阵的存储布局(即 n_batch 和 n_dense)。 |
|
沿给定轴对数组元素求和。 |
|
Jax.lax.reshape() 的稀疏实现。 |
|
Jax.lax.slice() 的稀疏实现。 |
|
对 BCOO 数组的索引进行排序。 |
|
Jax.lax.squeeze() 的稀疏实现。 |
|
对 BCOO 数组中的重复索引求和,返回一个具有排序索引的数组。 |
|
将批处理的稀疏矩阵转换为密集矩阵。 |
|
转置 BCOO 格式数组。 |
BCSR 数据结构#
BCSR
是批处理压缩稀疏行格式,目前正在开发中。其操作与 JAX 的核心变换兼容,包括批处理(例如 `jax.vmap()`)和自动微分(例如 `jax.grad()`)。
|
JAX 中实现的实验性批处理 CSR 矩阵。 |
|
一个通用的收缩操作。 |
|
根据 BCSR(索引、indptr)从密集矩阵中提取值。 |
|
从密集矩阵创建 BCSR 格式的稀疏矩阵。 |
|
将批处理的稀疏矩阵转换为密集矩阵。 |
其他稀疏数据结构#
其他稀疏数据结构包括 COO
、CSR
和 CSC
。这些是简单稀疏结构的一些核心操作的参考实现。它们的操作通常与自动微分变换(如 `jax.grad()`)兼容,但与批处理变换(如 `jax.vmap()`)不兼容。
|
JAX 中实现的实验性 COO 矩阵。 |
|
JAX 中实现的实验性 CSC 矩阵;API 可能会发生变化。 |
|
JAX 中实现的实验性 CSR 矩阵。 |
|
从密集矩阵创建 COO 格式的稀疏矩阵。 |
|
COO 稀疏矩阵与密集矩阵的乘积。 |
|
COO 稀疏矩阵与密集向量的乘积。 |
|
将 COO 格式的稀疏矩阵转换为密集矩阵。 |
|
从密集矩阵创建 CSR 格式的稀疏矩阵。 |
|
CSR 稀疏矩阵与密集矩阵的乘积。 |
|
CSR 稀疏矩阵与密集向量的乘积。 |
|
将 CSR 格式的稀疏矩阵转换为密集矩阵。 |
jax.experimental.sparse.linalg
#
稀疏线性代数例程。
|
使用 QR 分解的稀疏直接求解器。 |
|
使用 LOBPCG 例程计算前 k 个标准特征值。 |