jax.experimental.sparse.BCOO

class jax.experimental.sparse.BCOO(args, *, shape, indices_sorted=False, unique_indices=False)

Experimental batched COO (coordinate format) sparse matrix implemented in JAX.

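Parameters
  • args (Tuple[Array, Array]) – the pair (data, indices) in batched COO format.

  • shape (Sequence[int]) – shape of the sparse array.

  • indices_sorted (bool) – set to True only if the indices are known to be sorted. Defaults to False.

  • unique_indices (bool) – set to True only if the indices are known to contain no duplicates. Defaults to False.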

data

ndarray of shape [*batch_dims, nse, *dense_dims] containing the explicitly stored data within the sparse matrix.

Type

jax.Array
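To make the [*batch_dims, nse, *dense_dims] layout concrete, here is a minimal sketch that stores the same 2x3 matrix in three layouts and checks the shape of data for each. It assumes BCOO.fromdense accepts the n_batch and n_dense keyword arguments, as in current JAX releases; the variable names are only for illustration.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

M = jnp.array([[0., 2., 0.], [1., 0., 4.]])

# Default layout: no batch or dense dimensions, so data has shape [nse].
flat = BCOO.fromdense(M)
print(flat.data.shape)     # (3,)

# n_dense=1: the column axis is stored densely, so each stored element is a
# whole row and data has shape [nse, 3]. Both rows are nonzero, so nse == 2.
rows = BCOO.fromdense(M, n_dense=1)
print(rows.data.shape)     # (2, 3)

# n_batch=1: the row axis is a batch dimension, so data has shape [2, nse],
# where nse is the largest number of nonzeros in any one row (2 here).
batched = BCOO.fromdense(M, n_batch=1)
print(batched.data.shape)  # (2, 2)

# All three layouts represent the same dense matrix.
for sp in (flat, rows, batched):
    assert (sp.todense() == M).all()
```

Shorter rows in the batched layout are padded up to the common nse, which is why batching pays off mainly when every batch has a similar number of nonzeros.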

indices

ndarray of shape [*batch_dims, nse, n_sparse] containing the indices of the explicitly stored data. Duplicate entries will be summed.

Type

jax.Array
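A short sketch of what "duplicate entries will be summed" means in practice: two stored entries point at the same position, both stay in storage, and densifying adds them. The data values are arbitrary and chosen only for illustration.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

# Two explicitly stored entries that both refer to position (0, 0).
data = jnp.array([1., 2.])
indices = jnp.array([[0, 0],
                     [0, 0]])
dup = BCOO((data, indices), shape=(2, 2))

print(dup.nse)        # 2: both entries are kept in storage
print(dup.todense())  # the value at (0, 0) is 1. + 2. = 3.
```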

Examples

Create a sparse array from a dense array:

>>> import jax.numpy as jnp
>>> from jax.experimental.sparse import BCOO
>>> M = jnp.array([[0., 2., 0.], [1., 0., 4.]])
>>> M_sp = BCOO.fromdense(M)
>>> M_sp
BCOO(float32[2, 3], nse=3)

Examine the internal representation:

>>> M_sp.data
Array([2., 1., 4.], dtype=float32)
>>> M_sp.indices
Array([[0, 1],
       [1, 0],
       [1, 2]], dtype=int32)

Create a dense array from a sparse array:

>>> M_sp.todense()
Array([[0., 2., 0.],
       [1., 0., 4.]], dtype=float32)

Create a sparse array from COO data & indices:

>>> data = jnp.array([1., 3., 5.])
>>> indices = jnp.array([[0, 0],
...                      [1, 1],
...                      [2, 2]])
>>> mat = BCOO((data, indices), shape=(3, 3))
>>> mat
BCOO(float32[3, 3], nse=3)
>>> mat.todense()
Array([[1., 0., 0.],
       [0., 3., 0.],
       [0., 0., 5.]], dtype=float32)

__init__(args, *, shape, indices_sorted=False, unique_indices=False)

Methods

__init__(args, *, shape[, indices_sorted, ...])

Construct a BCOO array from (data, indices) and a shape.

astype(*args, **kwargs)

Copy the array and cast to a specified dtype.
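A minimal sketch of astype, assuming it accepts a target dtype positionally the way jnp.ndarray.astype does. Only the stored values are cast; the sparsity pattern is unchanged.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

M_sp = BCOO.fromdense(jnp.array([[0., 2., 0.], [1., 0., 4.]]))

# Cast the stored values to int32; the indices are left as they are.
M_int = M_sp.astype(jnp.int32)
print(M_int.dtype)  # int32
```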

block_until_ready()

from_scipy_sparse(mat, *[, index_dtype, ...])

Create a BCOO array from a scipy.sparse array.

fromdense(mat, *[, nse, index_dtype, ...])

Create a BCOO array from a (dense) jax.Array.
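The nse argument of fromdense matters mainly under transformations. The sketch below assumes, as is the case in current JAX, that outside jit the number of nonzeros can be computed from the concrete input, while inside jit it is not known at trace time and must be given as a static Python int. The helper name roundtrip is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

M = jnp.array([[0., 2., 0.], [1., 0., 4.]])

# Outside of jit, nse defaults to the actual number of nonzeros.
print(BCOO.fromdense(M).nse)  # 3

# A larger nse reserves padding slots; padding does not change the dense value.
padded = BCOO.fromdense(M, nse=5)
print(padded.nse)  # 5

@jax.jit
def roundtrip(m):
    # Inside jit, nse is passed explicitly as a static Python int.
    return BCOO.fromdense(m, nse=3).todense()
```

Choosing nse as an upper bound on the nonzero count keeps shapes static across calls, at the cost of some padded storage.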

reshape(*args, **kwargs)

Returns an array containing the same data with a new shape.
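A minimal sketch of reshape on a sparse array, assuming it accepts a new shape tuple the way jnp.ndarray.reshape does. The result stays sparse, so it is densified only to compare against the dense reshape.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

M = jnp.array([[0., 2., 0.], [1., 0., 4.]])
M_sp = BCOO.fromdense(M)

# Reshape the 2x3 sparse array to 3x2 without densifying it first.
R = M_sp.reshape((3, 2))
print(R.shape)  # (3, 2)
```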

sort_indices()

Return a copy of the matrix with indices sorted.
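A short sketch of sort_indices on a matrix whose entries were stored out of order. It assumes the returned copy records its sorted state in the indices_sorted flag described under Parameters; data is permuted together with the indices so the dense value is unchanged.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

data = jnp.array([5., 1., 3.])
indices = jnp.array([[2, 2], [0, 0], [1, 1]])
mat = BCOO((data, indices), shape=(3, 3))

s = mat.sort_indices()
print(s.indices.tolist())  # [[0, 0], [1, 1], [2, 2]]
print(s.data.tolist())     # [1.0, 3.0, 5.0]
print(s.indices_sorted)    # True
```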

sum(*args, **kwargs)

Sum array along axis.
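A minimal sketch of sum along each axis, assuming it accepts an axis keyword like jnp.sum. The reduction returns another sparse array, which is why todense() is called before printing.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

M_sp = BCOO.fromdense(jnp.array([[0., 2., 0.], [1., 0., 4.]]))

col_sums = M_sp.sum(axis=0)  # sparse result of shape (3,)
row_sums = M_sp.sum(axis=1)  # sparse result of shape (2,)
print(col_sums.todense())    # [1. 2. 4.]
print(row_sums.todense())    # [2. 5.]
```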

sum_duplicates([nse, remove_zeros])

Return a copy of the array with duplicate indices summed.
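A sketch of sum_duplicates, showing both the eager form, where the output nse is computed from the data, and a jit-compatible form where nse is passed explicitly. The choice of nse=2 matches the two distinct positions in this illustrative input.

```python
import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

# Positions (0, 0) appears twice and (1, 1) once.
data = jnp.array([1., 2., 4.])
indices = jnp.array([[0, 0], [0, 0], [1, 1]])
dup = BCOO((data, indices), shape=(2, 2))

# Eager: the output nse is computed from the contents.
merged = dup.sum_duplicates()
print(merged.nse)  # 2

# Under jit, the output nse must be known statically.
merged_jit = jax.jit(lambda m: m.sum_duplicates(nse=2))(dup)
```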

todense()

Create a dense version of the array.

transpose([axes])

Create a new array containing the transpose.
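A minimal sketch of transpose and the equivalent T attribute on a two-dimensional sparse array; with no axes argument the axis order is reversed, as for dense arrays.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

M = jnp.array([[0., 2., 0.], [1., 0., 4.]])
M_sp = BCOO.fromdense(M)

Mt = M_sp.transpose()  # reverses the axis order
print(Mt.shape)        # (3, 2)

# The T attribute gives the same result for a 2-D array.
assert (M_sp.T.todense() == M.T).all()
```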

tree_flatten()

tree_unflatten(aux_data, children)
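tree_flatten and tree_unflatten make BCOO a JAX pytree, so it can be passed through jax.jit and jax.tree_util. This sketch assumes the leaves are the data and indices arrays, in that order, with shape and the flags carried as static auxiliary data; the helper name densify is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

M = jnp.array([[0., 2., 0.], [1., 0., 4.]])
M_sp = BCOO.fromdense(M)

# The array leaves are data and indices; shape is static auxiliary data.
leaves, treedef = jax.tree_util.tree_flatten(M_sp)
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

@jax.jit
def densify(m):
    # m arrives as a BCOO because it is a registered pytree.
    return m.todense()
```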

update_layout(*[, n_batch, n_dense, ...])

Update the storage layout (i.e. n_batch & n_dense) of a BCOO matrix.

Attributes

T

dtype

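n_batch

Number of batch dimensions.

n_dense

Number of dense dimensions.

n_sparse

Number of sparse dimensions.

ndim

Number of dimensions.

Return type

int

nse

Number of specified (explicitly stored) elements.

size

Return type

int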

data

indices

shape

indices_sorted

unique_indices
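A minimal sketch relating the layout attributes: every dimension is either a batch, sparse, or dense dimension, so the three counts add up to ndim. It assumes fromdense accepts n_batch as a keyword argument.

```python
import jax.numpy as jnp
from jax.experimental.sparse import BCOO

M = jnp.array([[0., 2., 0.], [1., 0., 4.]])
batched = BCOO.fromdense(M, n_batch=1)

print(batched.n_batch, batched.n_sparse, batched.n_dense)  # 1 1 0
print(batched.ndim, batched.nse)                           # 2 2
assert batched.n_batch + batched.n_sparse + batched.n_dense == batched.ndim
```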