jax.experimental.sparse.random_bcoo

jax.experimental.sparse.random_bcoo(key, shape, *, dtype=<class 'jax.numpy.float64'>, indices_dtype=None, nse=0.2, n_batch=0, n_dense=0, unique_indices=True, sorted_indices=False, generator=<function uniform>, **kwds)

Generate a random BCOO matrix.

Parameters:
  • key – PRNG key to be passed to generator function.

  • shape – tuple specifying the shape of the array to be generated.

  • dtype – dtype of the array to be generated.

  • indices_dtype – dtype of the BCOO indices.

  • nse – number of specified elements in the matrix, or if 0 < nse < 1, the fraction of entries in the sparse dimensions to be specified (default: 0.2).

  • n_batch – number of batch dimensions. Must satisfy n_batch >= 0 and n_batch + n_dense <= len(shape).

  • n_dense – number of dense dimensions. Must satisfy n_dense >= 0 and n_batch + n_dense <= len(shape).

  • unique_indices – boolean specifying whether indices should be unique (default: True).

  • sorted_indices – boolean specifying whether indices should be row-sorted in lexicographical order (default: False).

  • generator – function for generating random values accepting a key, shape, and dtype. It defaults to jax.random.uniform(), and may be any function with a similar signature.

  • **kwds – additional keyword arguments to pass to generator.
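
The parameters above can be illustrated with a minimal sketch. The shapes, the key seed, and the choice of float32 are arbitrary assumptions made for this example. The fractional case uses nse=0.25 so that the resulting element count is an exact integer.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

key = jax.random.PRNGKey(0)  # arbitrary seed chosen for illustration

# Integer nse: request 10 specified elements in a 5x8 matrix.
m = sparse.random_bcoo(key, (5, 8), dtype=jnp.float32, nse=10)

# Fractional nse with one batch dimension: the leading axis (size 3) is the
# batch axis, and 0.25 of the 4 * 8 = 32 sparse positions are specified
# in each batch entry.
mb = sparse.random_bcoo(key, (3, 4, 8), dtype=jnp.float32, nse=0.25, n_batch=1)

# One dense dimension: the trailing axis (size 2) is stored densely, so each
# specified element carries a length-2 block of values.
md = sparse.random_bcoo(key, (4, 8, 2), dtype=jnp.float32, nse=0.25, n_dense=1)

# Extra keyword arguments are forwarded to the generator; here minval and
# maxval go to the default jax.random.uniform.
mu = sparse.random_bcoo(key, (5, 8), dtype=jnp.float32, nse=10,
                        minval=1.0, maxval=2.0)

# A different generator with a compatible (key, shape, dtype) signature.
mn = sparse.random_bcoo(key, (5, 8), dtype=jnp.float32, nse=10,
                        generator=jax.random.normal)

print(m.data.shape, m.indices.shape)
print(mb.data.shape, mb.indices.shape)
print(md.data.shape, md.indices.shape)
```

Inspecting data.shape and indices.shape is a quick way to confirm how n_batch and n_dense split the array's axes into batch, sparse, and dense groups.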

Returns:

a sparse.BCOO array with the specified properties.

Return type:

arr
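
As a usage sketch, the returned array can be treated like any other sparse.BCOO object. The example below assumes a small 6x4 matrix with five specified elements (values chosen only for illustration) and checks that a sparse matrix-vector product agrees with its dense counterpart.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

key = jax.random.PRNGKey(42)  # arbitrary seed chosen for illustration

# Random 6x4 BCOO matrix with 5 specified elements and sorted indices.
mat = sparse.random_bcoo(key, (6, 4), dtype=jnp.float32, nse=5,
                         sorted_indices=True)

# Convert to a dense array for comparison.
dense = mat.todense()

# Sparse and dense matrix-vector products should agree numerically.
x = jnp.ones((4,), dtype=jnp.float32)
y_sparse = mat @ x
y_dense = dense @ x

print(type(mat).__name__, mat.shape, mat.nse)
print(jnp.allclose(y_sparse, y_dense))
```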