jax.experimental.sparse.BCSR

class jax.experimental.sparse.BCSR(args, *, shape, indices_sorted=False, unique_indices=False)

Experimental batched CSR (compressed sparse row) matrix implemented in JAX.

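Parameters:
args (tuple[Array, Array, Array]): the (data, indices, indptr) buffers of the matrix.
shape (Sequence[int]): the shape of the full array, including any batch dimensions.
indices_sorted (bool): whether the column indices are known to be sorted within each row. Defaults to False.
unique_indices (bool): whether each stored element is known to have a distinct index. Defaults to False.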
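As an illustration, here is a minimal sketch (assuming JAX's default 32-bit mode; the matrix values are arbitrary) that builds an unbatched 3x4 matrix directly from its three buffers. `indices` holds the column position of each stored value, and the slice `indptr[i]:indptr[i + 1]` of `data` and `indices` belongs to row `i`, so an empty row simply repeats the previous `indptr` entry.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Dense equivalent (arbitrary example values):
# [[1., 0., 2., 0.],
#  [0., 0., 0., 0.],
#  [0., 3., 0., 4.]]
data = jnp.array([1.0, 2.0, 3.0, 4.0])  # stored values, listed row by row
indices = jnp.array([0, 2, 1, 3])       # column index of each stored value
indptr = jnp.array([0, 2, 2, 4])        # row i spans data[indptr[i]:indptr[i + 1]]

mat = sparse.BCSR((data, indices, indptr), shape=(3, 4),
                  indices_sorted=True, unique_indices=True)

dense = mat.todense()
print(dense)
```

Passing `indices_sorted=True` and `unique_indices=True` is only appropriate here because the buffers were written that way by hand; both flags default to False.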

Methods

__init__(args, *, shape[, indices_sorted, ...])

Initialize a BCSR array from its (data, indices, indptr) buffers.

block_until_ready()

from_bcoo(arr)

Create a BCSR array from a BCOO array.

from_scipy_sparse(mat, *[, index_dtype, ...])

Create a BCSR array from a scipy.sparse array.

fromdense(mat, *[, nse, index_dtype, ...])

Create a BCSR array from a (dense) Array.

sum(*args, **kwargs)

sum_duplicates([nse, remove_zeros])

Return a copy of the array with duplicate indices summed.

to_bcoo()

Convert the array to a BCOO array.

todense()

Create a dense version of the array.

transpose(*args, **kwargs)

tree_flatten()

tree_unflatten(aux_data, children)
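The conversion methods listed above can be exercised with a short round trip. This is a sketch under the same assumptions (default 32-bit mode, arbitrary input values). The last matrix is a hypothetical input with a repeated (0, 0) entry, used to show `sum_duplicates`. It is called outside of `jax.jit` so that the output `nse` can be inferred from the values.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

M = jnp.array([[0.0, 1.0, 0.0],
               [2.0, 0.0, 3.0]])

# Dense -> BCSR; with nse left as None it is inferred from M (3 nonzeros here).
mat = sparse.BCSR.fromdense(M)
assert mat.nse == 3
assert jnp.array_equal(mat.todense(), M)

# BCSR -> BCOO -> BCSR round trip.
coo = mat.to_bcoo()
assert isinstance(coo, sparse.BCOO)
back = sparse.BCSR.from_bcoo(coo)
assert jnp.array_equal(back.todense(), M)

# BCSR is registered as a pytree (tree_flatten / tree_unflatten),
# so it can be passed straight into a jitted function.
dense_from_jit = jax.jit(lambda m: m.todense())(mat)
assert jnp.array_equal(dense_from_jit, M)

# Hypothetical input: two stored values that share the index (0, 0).
dup = sparse.BCSR((jnp.array([1.0, 2.0]), jnp.array([0, 0]), jnp.array([0, 2])),
                  shape=(1, 2))
summed = dup.sum_duplicates()  # combines the repeated entry into one element
assert summed.nse == 1
assert jnp.array_equal(summed.todense(), jnp.array([[3.0, 0.0]]))
```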

Attributes

T

dtype

n_batch

n_dense

n_sparse

ndim

nse

size

data

indices

indptr

shape

indices_sorted

unique_indices
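To see how the attributes relate for a batched array, the sketch below (same assumptions; the values are arbitrary) converts a stack of two 2x2 matrices with one leading batch dimension. The two sparse dimensions are always the row and column axes. `nse` counts stored elements per batch entry, so the batch with fewer nonzeros is padded up to the largest per-batch count.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Two 2x2 matrices stacked along a leading batch axis: shape (2, 2, 2).
M = jnp.array([[[1.0, 0.0],
                [0.0, 2.0]],
               [[0.0, 3.0],
                [0.0, 0.0]]])

mat = sparse.BCSR.fromdense(M, n_batch=1)

print(mat.shape, mat.ndim, mat.size)           # logical (dense) shape, rank, element count
print(mat.n_batch, mat.n_sparse, mat.n_dense)  # how the dimensions are split
print(mat.nse)                                 # stored elements per batch entry
print(mat.dtype)                               # dtype of the stored values
# Buffer shapes: data and indices are (batch, nse); indptr is (batch, rows + 1).
print(mat.data.shape, mat.indices.shape, mat.indptr.shape)
```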