jax.experimental.sparse.BCOO
- class jax.experimental.sparse.BCOO(args, *, shape, indices_sorted=False, unique_indices=False)
Experimental batched COO matrix implemented in JAX.
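- Parameters:
  - args: a tuple (data, indices) of arrays in batched COO format:
    - data: ndarray of shape [*batch_dims, nse, *dense_dims] containing the explicitly stored data within the sparse matrix.
    - indices: ndarray of shape [*batch_dims, nse, n_sparse] containing the indices of the explicitly stored data. Duplicate entries will be summed.
  - shape: the shape of the sparse array.
  - indices_sorted, unique_indices: flags declaring that the indices are already sorted or already free of duplicates (both default to False).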
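To make the [*batch_dims, nse, *dense_dims] and [*batch_dims, nse, n_sparse] layouts concrete, here is a minimal sketch, assuming a recent JAX release in which BCOO.fromdense accepts the n_batch and n_dense keywords. It stores the same 2x3 matrix three ways and prints the resulting buffer shapes.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

# A 2x3 dense matrix with three nonzero entries.
M = jnp.array([[0., 2., 0.],
               [1., 0., 4.]])

# Default layout: no batch or dense dimensions, so both axes are sparse.
M_sp = BCOO.fromdense(M)
print(M_sp.data.shape)     # (3,)      -> [nse]
print(M_sp.indices.shape)  # (3, 2)    -> [nse, n_sparse]

# Leading axis as a batch dimension: each row stores its own entries, and
# nse is the largest per-row count (row 0 has 1 nonzero, row 1 has 2).
M_b = BCOO.fromdense(M, n_batch=1)
print(M_b.data.shape)      # (2, 2)    -> [*batch_dims, nse]
print(M_b.indices.shape)   # (2, 2, 1) -> [*batch_dims, nse, n_sparse]

# Trailing axis as a dense dimension: each stored element is a whole row.
M_d = BCOO.fromdense(M, n_dense=1)
print(M_d.data.shape)      # (2, 3)    -> [nse, *dense_dims]
print(M_d.indices.shape)   # (2, 1)    -> [nse, n_sparse]
```

All three objects represent the same matrix; only the split between batch, sparse, and dense axes differs, which trades memory for regularity of the stored buffers.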
Examples
Create a sparse array from a dense array:
>>> import jax.numpy as jnp
>>> from jax.experimental.sparse import BCOO
>>> M = jnp.array([[0., 2., 0.], [1., 0., 4.]])
>>> M_sp = BCOO.fromdense(M)
>>> M_sp
BCOO(float32[2, 3], nse=3)
Examine the internal representation:
>>> M_sp.data
Array([2., 1., 4.], dtype=float32)
>>> M_sp.indices
Array([[0, 1],
       [1, 0],
       [1, 2]], dtype=int32)
Create a dense array from a sparse array:
>>> M_sp.todense()
Array([[0., 2., 0.],
       [1., 0., 4.]], dtype=float32)
Create a sparse array from COO data & indices:
>>> data = jnp.array([1., 3., 5.])
>>> indices = jnp.array([[0, 0],
...                      [1, 1],
...                      [2, 2]])
>>> mat = BCOO((data, indices), shape=(3, 3))
>>> mat
BCOO(float32[3, 3], nse=3)
>>> mat.todense()
Array([[1., 0., 0.],
       [0., 3., 0.],
       [0., 0., 5.]], dtype=float32)
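The indices description notes that duplicate entries are summed. The following is a minimal sketch, assuming a recent JAX release, that builds a BCOO array from (data, indices) with a repeated index and shows that the duplicates stay separate in storage until they are densified or merged with sum_duplicates().

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

# Two stored entries share the index (0, 0); a third sits at (1, 1).
data = jnp.array([1., 2., 5.])
indices = jnp.array([[0, 0],
                     [0, 0],
                     [1, 1]])
mat = BCOO((data, indices), shape=(2, 2))
print(mat.nse)        # 3: the duplicates are stored separately

# Densifying adds the duplicate entries together.
print(mat.todense())  # dense values [[3, 0], [0, 5]]

# sum_duplicates() merges them in the sparse representation itself.
dedup = mat.sum_duplicates()
print(dedup.nse)      # 2
```

Outside of jit, sum_duplicates() can compute the new nse from the concrete indices; under a JAX transformation a static nse would generally need to be passed explicitly.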
Methods
- __init__(args, *, shape[, indices_sorted, ...]): Construct a BCOO array from a (data, indices) tuple and a shape.
- astype(*args, **kwargs): Copy the array and cast to a specified dtype.
- block_until_ready(): Block until the array's data and indices buffers are ready.
- from_scipy_sparse(mat, *[, index_dtype, ...]): Create a BCOO array from a scipy.sparse array.
- fromdense(mat, *[, nse, index_dtype, ...]): Create a BCOO array from a (dense) Array.
- reshape(*args, **kwargs): Return an array containing the same data with a new shape.
- sort_indices(): Return a copy of the matrix with indices sorted.
- sum(*args, **kwargs): Sum array along axis.
- sum_duplicates([nse, remove_zeros]): Return a copy of the array with duplicate indices summed.
- todense(): Create a dense version of the array.
- transpose([axes]): Create a new array containing the transpose.
- tree_flatten(): Flatten the array into its (data, indices) children and static metadata, for JAX pytree handling.
- tree_unflatten(aux_data, children): Rebuild an array from the output of tree_flatten.
- update_layout(*[, n_batch, n_dense, ...]): Update the storage layout (i.e. n_batch and n_dense) of the array.
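Attributes
- T: the transpose of the array.
- dtype: the dtype of the stored data.
- n_batch: the number of batch dimensions.
- n_dense: the number of dense dimensions.
- n_sparse: the number of sparse dimensions.
- ndim: the total number of dimensions (n_batch + n_sparse + n_dense).
- nse: the number of specified (explicitly stored) elements per batch entry.
- size: the total number of elements in the equivalent dense array.
- shape: the shape of the array.
- indices_sorted: whether the indices are known to be sorted.
- unique_indices: whether the indices are known to be free of duplicates.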
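The dimension-count attributes are tied together by the storage layout. As a minimal sketch, assuming a recent JAX release, the 3-D array below is stored with one batch, one sparse, and one dense dimension, and the attributes are checked against the data and indices buffer shapes.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

# A 2x3x4 array whose nonzeros are two full length-4 slices.
x = jnp.zeros((2, 3, 4)).at[0, 1].set(1.0).at[1, 2].set(2.0)
x_sp = BCOO.fromdense(x, n_batch=1, n_dense=1)

print(x_sp.n_batch, x_sp.n_sparse, x_sp.n_dense)  # 1 1 1
print(x_sp.ndim, x_sp.shape)                      # 3 (2, 3, 4)

# Every axis is exactly one of batch, sparse, or dense.
assert x_sp.n_batch + x_sp.n_sparse + x_sp.n_dense == x_sp.ndim

# Each batch entry stores one length-4 slice, so nse is 1.
print(x_sp.nse)            # 1
print(x_sp.data.shape)     # (2, 1, 4) -> [*batch_dims, nse, *dense_dims]
print(x_sp.indices.shape)  # (2, 1, 1) -> [*batch_dims, nse, n_sparse]
```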