jax.numpy.indices(dimensions, dtype=<class 'jax.numpy.int32'>, sparse=False)

Return an array representing the indices of a grid.

LAX-backend implementation of numpy.indices().

Original docstring below.

Compute an array where the subarrays contain index values 0, 1, … varying only along the corresponding axis.
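The example below is a minimal sketch and assumes both JAX and NumPy are installed. It uses a 2-D grid of shape (2, 3). The first subarray holds each cell's row index and the second holds its column index, so each one varies along only its own axis. The final check compares the result against numpy.indices, which this function mirrors.

```python
import numpy as np
import jax.numpy as jnp

grid = jnp.indices((2, 3))

# grid[0] varies only along axis 0 (the row index).
print(grid[0])
# [[0 0 0]
#  [1 1 1]]

# grid[1] varies only along axis 1 (the column index).
print(grid[1])
# [[0 1 2]
#  [0 1 2]]

# The values agree with NumPy's reference implementation.
np.testing.assert_array_equal(np.asarray(grid), np.indices((2, 3)))
```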

Parameters:

  • dimensions (sequence of ints) – The shape of the grid.

  • dtype (dtype, optional) – Data type of the result.

  • sparse (boolean, optional) – Return a sparse representation of the grid instead of a dense representation. Default is False.
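The next sketch shows how the two optional parameters interact. The dtype argument sets the element type of the result. With sparse=True, that dtype applies to every array in the returned tuple. The choice of int16 is arbitrary. Note that whether a 64-bit integer dtype is honored depends on JAX's x64 configuration.

```python
import jax.numpy as jnp

# dtype controls the element type of the result (int16 is an arbitrary choice here).
dense = jnp.indices((2, 3), dtype=jnp.int16)
print(dense.dtype)   # int16

# With sparse=True the result is a tuple, and dtype applies to every member.
sparse = jnp.indices((2, 3), dtype=jnp.int16, sparse=True)
print(type(sparse).__name__, len(sparse))   # tuple 2
print([str(s.dtype) for s in sparse])       # ['int16', 'int16']
```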

Returns:

grid – the grid of indices, whose layout depends on sparse:

If sparse is False:

Returns one array of grid indices, where grid.shape = (len(dimensions),) + tuple(dimensions).

If sparse is True:

Returns a tuple of arrays, one per dimension, where grid[i].shape = (1, ..., 1, dimensions[i], 1, ..., 1), with dimensions[i] in the i-th position.
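The sketch below checks both shape rules on a 3-D grid with dimensions (2, 3, 4). The sparse arrays have length 1 on every axis except their own, so they broadcast against each other. Expanding them to the full grid shape reproduces the dense result. For this grid, the sparse form holds 2 + 3 + 4 = 9 index values, while the dense form holds 3 * 24 = 72.

```python
import jax.numpy as jnp

dimensions = (2, 3, 4)

# sparse=False: one array with a leading axis of length len(dimensions).
dense = jnp.indices(dimensions)
print(dense.shape)   # (3, 2, 3, 4)

# sparse=True: one array per axis, full length on axis i and length 1 elsewhere.
sparse = jnp.indices(dimensions, sparse=True)
print([s.shape for s in sparse])   # [(2, 1, 1), (1, 3, 1), (1, 1, 4)]

# The sparse arrays broadcast against each other; expanding them to the
# full grid shape and stacking reproduces the dense result.
rebuilt = jnp.stack(jnp.broadcast_arrays(*sparse))
assert (rebuilt == dense).all()

# Broadcasting also makes sparse grids convenient in arithmetic, e.g. the
# row-major flat index of every cell (strides 12, 4, 1 for shape (2, 3, 4)).
i, j, k = sparse
flat = i * 12 + j * 4 + k
assert (flat == jnp.arange(24).reshape(dimensions)).all()
```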

Return type:

one ndarray or tuple of ndarrays
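The dense grid can be unpacked and used directly as an advanced index into another array. In the sketch below, x is an arbitrary example array. For a rectangular block like this one, plain slicing with x[:2, :3] gives the same result more simply.

```python
import jax.numpy as jnp

x = jnp.arange(20).reshape(5, 4)

# Unpack the dense grid into row and column index arrays of shape (2, 3).
row, col = jnp.indices((2, 3))

# Advanced indexing picks x[row[a, b], col[a, b]] for every (a, b):
# the top-left 2x3 block of x.
block = x[row, col]
print(block)
# [[0 1 2]
#  [4 5 6]]

# The same block via basic slicing.
assert (block == x[:2, :3]).all()
```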