jax.numpy.indices

jax.numpy.indices(dimensions: Sequence[int], dtype: Union[Any, str, numpy.dtype, jax._src.typing.SupportsDType] = int32, sparse: Literal[False] = False) → jax.Array
jax.numpy.indices(dimensions: Sequence[int], dtype: Union[Any, str, numpy.dtype, jax._src.typing.SupportsDType] = int32, *, sparse: Literal[True]) → Tuple[jax.Array, ...]
jax.numpy.indices(dimensions: Sequence[int], dtype: Union[Any, str, numpy.dtype, jax._src.typing.SupportsDType] = int32, sparse: bool = False) → Union[jax.Array, Tuple[jax.Array, ...]]

Return an array representing the indices of a grid.

LAX-backend implementation of numpy.indices().

Original docstring below.

Compute an array where the subarrays contain index values 0, 1, … varying only along the corresponding axis.
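To make "varying only along the corresponding axis" concrete, here is a minimal sketch that builds the dense grid by broadcasting one 1-D ramp per axis. The helper indices_sketch is hypothetical and for illustration only; it is not claimed to be JAX's actual implementation, and it assumes dimensions is non-empty.

```python
import jax.numpy as jnp


def indices_sketch(dimensions, dtype=jnp.int32):
    """Hypothetical helper: dense index grid built from broadcast aranges."""
    n = len(dimensions)
    grids = []
    for axis, size in enumerate(dimensions):
        # A 1-D ramp 0, 1, ..., size - 1 ...
        ramp = jnp.arange(size, dtype=dtype)
        # ... reshaped to lie along `axis`, with length 1 on every other axis ...
        shape = [1] * n
        shape[axis] = size
        # ... then broadcast to the full grid, so its values change only along `axis`.
        grids.append(jnp.broadcast_to(ramp.reshape(shape), tuple(dimensions)))
    # Stack the per-axis grids along a new leading axis.
    return jnp.stack(grids)


grid = indices_sketch((2, 3))
print(grid.shape)  # (2, 2, 3)
print(grid[0])  # row index: constant along axis 1
# [[0 0 0]
#  [1 1 1]]
print(grid[1])  # column index: constant along axis 0
# [[0 1 2]
#  [0 1 2]]
```

For non-empty dimensions, the sketch should agree elementwise with jnp.indices(dimensions).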

Parameters
  • dimensions (sequence of ints) – The shape of the grid.

  • dtype (dtype, optional) – Data type of the result.

  • sparse (boolean, optional) – Return a sparse representation of the grid instead of a dense representation. Default is False.
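The sketch below illustrates the dtype parameter and one JAX-specific consequence of dimensions: because it fixes the output shape, it has to be known at trace time under jax.jit, so here it is passed as a static argument. The function row_major_offsets is a hypothetical example, not part of the JAX API.

```python
from functools import partial

import jax
import jax.numpy as jnp

# dtype controls the element type of the returned index array(s).
g = jnp.indices((2, 3), dtype=jnp.int16)
print(g.dtype)  # int16


# Hypothetical helper: `shape` determines the output shape, so it is marked
# static and jax.jit sees it as a concrete tuple of Python ints while tracing.
@partial(jax.jit, static_argnums=0)
def row_major_offsets(shape):
    i, j = jnp.indices(shape)
    # Flat row-major offset of each (i, j) position in a 2-D array of this shape.
    return i * shape[1] + j


print(row_major_offsets((2, 3)))
# [[0 1 2]
#  [3 4 5]]
```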

Returns

grid

If sparse is False:

Returns one array of grid indices, grid.shape = (len(dimensions),) + tuple(dimensions).
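A short check of the dense shape rule above: the result stacks one index array per axis along a new leading axis, so grid[:, i0, i1, i2] recovers the position (i0, i1, i2) itself.

```python
import jax.numpy as jnp

dims = (2, 3, 4)
grid = jnp.indices(dims)

# Dense result: shape is (len(dims),) + dims.
print(grid.shape)  # (3, 2, 3, 4)

# Reading across the leading axis returns the multi-index of that position.
print(grid[:, 1, 2, 3])  # [1 2 3]
```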

If sparse is True:

Returns a tuple of arrays, where grid[i].shape = (1, ..., 1, dimensions[i], 1, ..., 1), with dimensions[i] in the i-th place.
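The sparse form stores only one ramp per axis, each of length dimensions[i], instead of len(dimensions) full-size arrays. This sketch shows the per-axis shapes and that broadcasting the pieces back to the full shape reproduces the dense result.

```python
import jax.numpy as jnp

dims = (2, 3, 4)
sparse_grid = jnp.indices(dims, sparse=True)

# One array per axis: length dims[i] along axis i, length 1 everywhere else.
print([g.shape for g in sparse_grid])  # [(2, 1, 1), (1, 3, 1), (1, 1, 4)]

# Broadcasting each piece to the full shape and stacking recovers the dense grid.
dense = jnp.stack([jnp.broadcast_to(g, dims) for g in sparse_grid])
print(bool(jnp.array_equal(dense, jnp.indices(dims))))  # True
```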

Return type

one ndarray or tuple of ndarrays
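A common use of the returned grid(s) is advanced indexing. The example below adapts the indexing pattern from the NumPy documentation to jax.numpy; both the dense and sparse forms select the same top-left 2 x 3 block.

```python
import jax.numpy as jnp

x = jnp.arange(20).reshape(5, 4)

# Dense form: unpack the leading axis into one index array per dimension.
row, col = jnp.indices((2, 3))
print(x[row, col])
# [[0 1 2]
#  [4 5 6]]

# Sparse form: shapes (2, 1) and (1, 3) broadcast to (2, 3) during indexing.
r, c = jnp.indices((2, 3), sparse=True)
print(bool(jnp.array_equal(x[r, c], x[:2, :3])))  # True
```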