jax.numpy.indices


jax.numpy.indices(dimensions: Sequence[int], dtype: str | type[Any] | dtype | SupportsDType = int32, sparse: Literal[False] = False) -> Array
jax.numpy.indices(dimensions: Sequence[int], dtype: str | type[Any] | dtype | SupportsDType = int32, *, sparse: Literal[True]) -> tuple[Array, ...]
jax.numpy.indices(dimensions: Sequence[int], dtype: str | type[Any] | dtype | SupportsDType = int32, sparse: bool = False) -> Array | tuple[Array, ...]

Return an array representing the indices of a grid.

LAX-backend implementation of numpy.indices().

Original docstring below.

Compute an array where the subarrays contain index values 0, 1, … varying only along the corresponding axis.
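
As a minimal sketch of what "varying only along the corresponding axis" means, the snippet below builds the dense grid for a 2x3 shape. The names `grid`, `rows`, and `cols` are illustrative only and are not part of the API.

```python
import jax.numpy as jnp

# Dense index grid for a 2x3 shape: one subarray per grid axis.
grid = jnp.indices((2, 3))

rows = grid[0]  # row index of every cell; changes only along axis 0
cols = grid[1]  # column index of every cell; changes only along axis 1

print(grid.shape)  # (2, 2, 3)
print(rows)        # [[0 0 0]
                   #  [1 1 1]]
print(cols)        # [[0 1 2]
                   #  [0 1 2]]
```

Indexing an array `x` of shape (2, 3) as `x[rows, cols]` would return `x` itself, since each cell is addressed by its own coordinates.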

Parameters:
  • dimensions (sequence of ints) – The shape of the grid.

  • dtype (dtype, optional) – Data type of the result. Default is int32.

  • sparse (boolean, optional) – Return a sparse representation of the grid instead of a dense representation. Default is False.

Returns:

grid –
  • If sparse is False: one array of grid indices, with grid.shape = (len(dimensions),) + tuple(dimensions).

  • If sparse is True: a tuple of arrays, with grid[i].shape = (1, ..., 1, dimensions[i], 1, ..., 1), where dimensions[i] is in the ith position.

Return type:

one ndarray or tuple of ndarrays
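
The two return shapes described above can be compared with a short sketch. The shape `dims = (2, 3, 4)` and the variable names are arbitrary choices for illustration. The final check assumes standard broadcasting, under which the sparse arrays expand to the same values as the dense grid.

```python
import jax.numpy as jnp

dims = (2, 3, 4)

# Dense form: a single array with a leading axis of length len(dims).
dense = jnp.indices(dims)
dense_shape = dense.shape

# Sparse form: one array per axis, each with size 1 on every other axis.
sparse = jnp.indices(dims, sparse=True)
sparse_shapes = [g.shape for g in sparse]

# Broadcasting the sparse arrays against each other and stacking them
# reproduces the dense grid.
rebuilt = jnp.stack(jnp.broadcast_arrays(*sparse))
same = bool(jnp.array_equal(rebuilt, dense))

print(dense_shape)    # (3, 2, 3, 4)
print(sparse_shapes)  # [(2, 1, 1), (1, 3, 1), (1, 1, 4)]
print(same)           # True
```

The sparse form stores only sum(dims) index values instead of len(dims) * prod(dims), which can matter for large grids when the consumer broadcasts anyway.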