jax.numpy.indices
- jax.numpy.indices(dimensions: Sequence[int], dtype: str | type[Any] | dtype | SupportsDType = int32, sparse: Literal[False] = False) → Array
- jax.numpy.indices(dimensions: Sequence[int], dtype: str | type[Any] | dtype | SupportsDType = int32, *, sparse: Literal[True]) → tuple[Array, ...]
- jax.numpy.indices(dimensions: Sequence[int], dtype: str | type[Any] | dtype | SupportsDType = int32, sparse: bool = False) → Array | tuple[Array, ...]
Return an array representing the indices of a grid.
LAX-backend implementation of numpy.indices(). Original docstring below.
Compute an array where the subarrays contain index values 0, 1, … varying only along the corresponding axis.
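As an illustration of the sentence above, the short sketch below builds a dense grid for a 2 x 3 shape and inspects each subarray. The shape (2, 3) is an arbitrary choice for the example.

```python
import jax.numpy as jnp

# Dense grid for a 2 x 3 array: grid[0] holds row indices, grid[1] column indices.
grid = jnp.indices((2, 3))

print(grid.shape)  # (2, 2, 3): one leading axis per dimension, then the shape itself
print(grid[0])     # values [[0, 0, 0], [1, 1, 1]]: varies only along axis 0
print(grid[1])     # values [[0, 1, 2], [0, 1, 2]]: varies only along axis 1
```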
- Parameters:
dimensions (sequence of ints) – The shape of the grid.
dtype (dtype, optional) – Data type of the result. Default is int32.
sparse (boolean, optional) – Return a sparse representation of the grid instead of a dense representation. Default is False.
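A minimal sketch of the dtype parameter, assuming JAX's default 32-bit mode (jax_enable_x64 off). The choice of int16 as the override is illustrative; any integer dtype should work the same way.

```python
import jax.numpy as jnp

# Default dtype: int32 per the signature (under JAX's default 32-bit mode).
g_default = jnp.indices((3, 4))
print(g_default.dtype)

# Override the element type of the returned grid via dtype.
g_small = jnp.indices((3, 4), dtype=jnp.int16)
print(g_small.dtype)  # int16
print(g_small.shape)  # (2, 3, 4)
```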
- Returns:
grid –
- If sparse is False:
Returns one array of grid indices, with grid.shape = (len(dimensions),) + tuple(dimensions).
- If sparse is True:
Returns a tuple of arrays, with grid[i].shape = (1, ..., 1, dimensions[i], 1, ..., 1), where dimensions[i] is in the i-th place.
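To make the two return shapes concrete, the sketch below compares sparse and dense output for the shape (2, 3) and shows that the sparse pieces broadcast against each other. The row-major flat index is just one illustrative use of that broadcasting.

```python
import jax.numpy as jnp

# sparse=True: one array per dimension, singleton axes everywhere except its own.
ii, jj = jnp.indices((2, 3), sparse=True)
print(ii.shape)  # (2, 1)
print(jj.shape)  # (1, 3)

# Broadcasting each sparse piece to the full shape reproduces the dense grid.
dense = jnp.indices((2, 3))
rows_match = bool((jnp.broadcast_to(ii, (2, 3)) == dense[0]).all())
cols_match = bool((jnp.broadcast_to(jj, (2, 3)) == dense[1]).all())

# Sparse pieces broadcast in arithmetic, e.g. a row-major flat index.
flat = ii * 3 + jj  # values [[0, 1, 2], [3, 4, 5]]
```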
- Return type:
one ndarray or tuple of ndarrays
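Since this function is described as a LAX-backend implementation of numpy.indices(), a quick sanity check is to compare its values against NumPy and to use the sparse tuple directly as an index. This is a sketch assuming NumPy is installed alongside JAX; dtypes may differ between the two libraries (NumPy uses the platform integer), so only values and shapes are compared.

```python
import numpy as np
import jax.numpy as jnp

shape = (3, 2, 4)

# Dense results agree in value with NumPy.
same_dense = np.array_equal(np.asarray(jnp.indices(shape)), np.indices(shape))

# Sparse results are a tuple in both libraries, with matching shapes and values.
jax_sparse = jnp.indices(shape, sparse=True)
np_sparse = np.indices(shape, sparse=True)
same_sparse = isinstance(jax_sparse, tuple) and all(
    a.shape == b.shape and np.array_equal(np.asarray(a), b)
    for a, b in zip(jax_sparse, np_sparse)
)

# Because the sparse form is a tuple, it can index an array directly;
# the pieces broadcast to the full shape, so this selects every element in place.
x = jnp.arange(24).reshape(shape)
roundtrip = bool((x[jax_sparse] == x).all())
print(same_dense, same_sparse, roundtrip)  # True True True
```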