jax.numpy.indices
- jax.numpy.indices(dimensions, dtype=jax.numpy.int32, sparse=False)
Return an array representing the indices of a grid.
LAX-backend implementation of numpy.indices(). Original docstring below.
Compute an array where the subarrays contain index values 0, 1, … varying only along the corresponding axis.
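As a small illustration of the description above, the sketch below builds a dense index grid for the shape (2, 3). The shape is chosen only for readability. Each subarray along the leading axis holds the index for one axis of the grid and varies only along that axis.

```python
import jax.numpy as jnp

# Dense index grid for a grid of shape (2, 3): one index array per axis.
grid = jnp.indices((2, 3))
print(grid.shape)  # (2, 2, 3)

# grid[0] holds the row index, which varies only along axis 0.
print(grid[0])
# [[0 0 0]
#  [1 1 1]]

# grid[1] holds the column index, which varies only along axis 1.
print(grid[1])
# [[0 1 2]
#  [0 1 2]]
```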
- Parameters:
dimensions (sequence of ints) – The shape of the grid.
dtype (dtype, optional) – Data type of the result.
sparse (boolean, optional) – Return a sparse representation of the grid instead of a dense representation. Default is False.
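The following sketch shows how the sparse and dtype parameters affect the result. With sparse=True, the function returns a tuple of arrays that broadcast against each other, and broadcasting them recovers the dense grid. The comparison relies on jnp.broadcast_arrays and jnp.array_equal, which are standard jax.numpy functions. The int16 dtype is an arbitrary choice for demonstration.

```python
import jax.numpy as jnp

# sparse=True returns a tuple of broadcastable arrays instead of one stacked array.
rows, cols = jnp.indices((2, 3), sparse=True)
print(rows.shape, cols.shape)  # (2, 1) (1, 3)

# Broadcasting the sparse arrays against each other recovers the dense grid.
dense = jnp.indices((2, 3))
rows_b, cols_b = jnp.broadcast_arrays(rows, cols)
print(bool(jnp.array_equal(rows_b, dense[0])), bool(jnp.array_equal(cols_b, dense[1])))  # True True

# dtype controls the integer type of the returned index values.
grid16 = jnp.indices((2, 3), dtype=jnp.int16)
print(grid16.dtype)  # int16
```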
- Returns:
grid –
- If sparse is False:
Returns one array of grid indices, with grid.shape = (len(dimensions),) + tuple(dimensions).
- If sparse is True:
Returns a tuple of arrays, with grid[i].shape = (1, ..., 1, dimensions[i], 1, ..., 1), where dimensions[i] appears in the ith place.
- Return type:
one ndarray or tuple of ndarrays
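One common use of the returned grid is fancy indexing. In the sketch below, the 5 x 4 array x is a hypothetical input chosen for illustration. Unpacking the dense result gives one index array per axis, and indexing x with them selects the leading (2, 3) block. The sparse form gives the same selection through broadcasting while storing fewer index values.

```python
import jax.numpy as jnp

# Hypothetical input array used only for illustration.
x = jnp.arange(20).reshape(5, 4)

# Dense grid: unpack the leading axis into one index array per dimension.
row, col = jnp.indices((2, 3))
print(x[row, col])
# [[0 1 2]
#  [4 5 6]]

# Sparse grid: the (2, 1) and (1, 3) index arrays broadcast to the same selection.
row_s, col_s = jnp.indices((2, 3), sparse=True)
print(jnp.array_equal(x[row_s, col_s], x[:2, :3]))  # True
```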