jax.experimental.sparse.JAXSparse
class jax.experimental.sparse.JAXSparse(args, *, shape)
Base class for high-level JAX sparse objects.
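JAXSparse is not normally used directly. Concrete formats such as `BCOO` and `BCSR` subclass it. The sketch below shows that relationship, and also shows that sparse objects are pytrees, which lets them pass through transformations such as `jax.jit`. The matrix `M`, the vector `v`, and the function `matvec` are illustrative values and names made up for this example, not part of the API.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

# Illustrative data: a 2x3 matrix with three nonzero entries.
M = jnp.array([[1.0, 0.0, 2.0],
               [0.0, 0.0, 3.0]])
v = jnp.ones(3)

# Concrete sparse formats are subclasses of JAXSparse.
M_bcoo = sparse.BCOO.fromdense(M)
M_bcsr = sparse.BCSR.fromdense(M)
print(isinstance(M_bcoo, sparse.JAXSparse))  # True
print(isinstance(M_bcsr, sparse.JAXSparse))  # True

# Sparse objects are registered pytrees; for BCOO the leaves are
# the buffers holding the stored values and their indices.
leaves = jax.tree_util.tree_leaves(M_bcoo)
print(len(leaves))  # 2

# Because they are pytrees, they can be passed straight into jit.
@jax.jit
def matvec(mat, vec):
  # Sparse-dense matrix-vector product.
  return mat @ vec

y = matvec(M_bcoo, v)
print(y)
```

Each row of `M` sums to 3, so `y` should equal the dense result `M @ v`.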
Methods
__init__(args, *, shape)
  param args:
block_until_ready()
sum(*args, **kwargs)
transpose([axes])
tree_flatten()
tree_unflatten(aux_data, children)
Attributes
T
ndim
size
data
shape
nse
dtype
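The methods and attributes above are easiest to see on a concrete subclass. The following sketch uses `BCOO` and assumes JAX's default 32-bit mode (so the dtype is `float32`). Two further assumptions apply. First, `BCOO.sum` accepts an `axis` argument the way `jnp.sum` does. Second, such a reduction may return a sparse result. For that reason the example uses `densify`, a hypothetical helper that is not part of the library.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Illustrative data: a 2x3 matrix with three nonzero entries.
M = jnp.array([[1.0, 0.0, 2.0],
               [0.0, 0.0, 3.0]])
M_sp = sparse.BCOO.fromdense(M)


def densify(x):
  """Hypothetical helper: return a dense array whether x is sparse or dense."""
  return x.todense() if isinstance(x, sparse.JAXSparse) else x


# Attributes
print(M_sp.shape)  # (2, 3)
print(M_sp.ndim)   # 2
print(M_sp.size)   # 6
print(M_sp.nse)    # 3 (number of stored elements)
print(M_sp.dtype)  # float32
print(M_sp.data)   # the stored values

# transpose() with no axes reverses the axis order; T is shorthand for it.
Mt = M_sp.transpose()
print(Mt.shape, M_sp.T.shape)  # (3, 2) (3, 2)

# Assumption: sum forwards *args/**kwargs, so axis= behaves like jnp.sum.
col_sums = densify(M_sp.sum(axis=0))
row_sums = densify(M_sp.sum(axis=1))

# block_until_ready waits until the underlying buffers are computed.
M_sp.block_until_ready()

# tree_flatten / tree_unflatten implement the pytree protocol:
# children are the array buffers, aux_data is the static metadata.
children, aux_data = M_sp.tree_flatten()
M_copy = sparse.BCOO.tree_unflatten(aux_data, children)
```

Densifying the transposed matrix and the round-tripped copy should reproduce `M.T` and `M`.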