jax.experimental.sparse.JAXSparse


class jax.experimental.sparse.JAXSparse(args, *, shape)

Base class for high-level JAX sparse objects. Concrete sparse formats in jax.experimental.sparse, such as BCOO and BCSR, subclass it.
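A minimal sketch, assuming a JAX installation in which `jax.experimental.sparse` provides the `BCOO` format. A matrix built with `BCOO.fromdense` is an instance of `JAXSparse`, so an `isinstance` check identifies a high-level sparse object regardless of its concrete format.

```python
import jax.numpy as jnp
from jax.experimental import sparse

dense = jnp.array([[0., 1., 0.],
                   [2., 0., 3.]])

# BCOO (batched coordinate format) is one concrete JAXSparse subclass.
mat = sparse.BCOO.fromdense(dense)

print(isinstance(mat, sparse.JAXSparse))    # True
print(isinstance(dense, sparse.JAXSparse))  # False: a plain jax.Array
# todense() converts the sparse object back to a dense jax.Array.
print(jnp.array_equal(mat.todense(), dense))  # True
```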

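Parameters:
args: the arrays that make up the sparse representation. Their number and meaning are defined by each subclass; for example, BCOO takes the pair (data, indices).
shape: the shape of the equivalent dense array. This argument is keyword-only.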
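To illustrate the args/shape convention, here is a sketch using the BCOO subclass, whose args tuple is (data, indices). Other subclasses expect different buffers, so treat this as one example rather than the general layout.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# For BCOO, args is the pair (data, indices):
#   data    holds the stored values
#   indices holds one (row, column) coordinate per stored value
data = jnp.array([1., 2., 3.])
indices = jnp.array([[0, 1],
                     [1, 0],
                     [1, 2]])

# shape is keyword-only and gives the shape of the equivalent dense array.
mat = sparse.BCOO((data, indices), shape=(2, 3))

print(mat.todense())
# [[0. 1. 0.]
#  [2. 0. 3.]]
```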

Methods

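__init__(args, *, shape)
Initialize the object from its component arrays (args) and the dense shape (shape).

block_until_ready()
Wait until every array buffer underlying the object has been computed.

sum(*args, **kwargs)
Sum the array elements (implemented by subclasses such as BCOO).

transpose([axes])
Return a new sparse object with its axes permuted; by default the order of the axes is reversed.

tree_flatten()
Flatten the object into its array children and static auxiliary data, as used by JAX's pytree machinery.

tree_unflatten(aux_data, children)
Class method that rebuilds an object from the output of tree_flatten().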
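The sketch below exercises these methods on a BCOO matrix (BCOO is used only as a convenient concrete subclass). It also shows the practical consequence of the pytree methods: sparse objects can be passed through `jax.jit` like ordinary arrays.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

dense = jnp.array([[0., 1., 0.],
                   [2., 0., 3.]])
mat = sparse.BCOO.fromdense(dense)

# transpose() reverses the axes by default; .T is shorthand for it.
print(mat.transpose().shape)  # (3, 2)
print(mat.T.shape)            # (3, 2)

# block_until_ready() waits for the underlying buffers, e.g. before timing.
mat.block_until_ready()

# tree_flatten() splits the object into array children and static aux data;
# tree_unflatten() rebuilds an equivalent object from those pieces.
children, aux_data = mat.tree_flatten()
rebuilt = type(mat).tree_unflatten(aux_data, children)
print(jnp.array_equal(rebuilt.todense(), dense))  # True

# Because sparse objects are pytrees, they can cross a jax.jit boundary.
@jax.jit
def matvec(m, v):
    return m @ v  # sparse-dense product yields a dense jax.Array

print(matvec(mat, jnp.ones(3)))  # [1. 5.]
```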

Attributes

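T
The transpose of the array; equivalent to transpose().

ndim
Number of dimensions, len(shape).

size
Number of elements in the equivalent dense array, the product of shape.

data
The array of stored (specified) element values.

shape
Shape of the equivalent dense array, as a tuple of ints.

nse
Number of specified elements, i.e. the entries that are explicitly stored.

dtype
Data type of the array elements.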
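A short sketch of the attributes on a small BCOO matrix. The values shown are specific to this example matrix, and the float32 dtype assumes JAX's default 32-bit mode (64-bit mode would give float64).

```python
import jax.numpy as jnp
from jax.experimental import sparse

dense = jnp.array([[0., 1., 0.],
                   [2., 0., 3.]])
mat = sparse.BCOO.fromdense(dense)

print(mat.shape)   # (2, 3)
print(mat.ndim)    # 2: len(shape)
print(mat.size)    # 6: elements in the equivalent dense array
print(mat.nse)     # 3: number of specified (stored) elements
print(mat.dtype)   # float32 under default (32-bit) settings
print(mat.data)    # [1. 2. 3.]: the stored values, in row-major order here
```

Note that size counts the elements of the dense array, while nse counts only the stored ones.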