jax.numpy.indices

jax.numpy.indices(dimensions, dtype=jax.numpy.int32, sparse=False)

Return an array representing the indices of a grid.

LAX-backend implementation of numpy.indices().
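As a quick illustration of what "LAX-backend implementation" means in practice, the following minimal sketch (assuming both NumPy and JAX are installed) checks that jnp.indices produces the same shape and values as numpy.indices. Only shapes and values are compared. The default integer dtypes may differ: NumPy typically uses a 64-bit integer on 64-bit platforms, while JAX uses int32 unless 64-bit mode is enabled.

```python
import numpy as np
import jax.numpy as jnp

# Build the same 2 x 3 index grid with JAX and with NumPy.
jax_grid = jnp.indices((2, 3))
np_grid = np.indices((2, 3))

print(jax_grid.shape)                                 # (2, 2, 3)
print(np.array_equal(np.asarray(jax_grid), np_grid))  # True
```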

Original docstring below.

Compute an array where the subarrays contain index values 0, 1, … varying only along the corresponding axis.
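For a concrete picture of "index values varying only along the corresponding axis", here is a small sketch for a 2 x 3 grid. The first subarray holds row indices, which change only along axis 0. The second holds column indices, which change only along axis 1.

```python
import jax.numpy as jnp

grid = jnp.indices((2, 3))

# grid[0]: row index of each cell, constant along axis 1.
print(grid[0])
# [[0 0 0]
#  [1 1 1]]

# grid[1]: column index of each cell, constant along axis 0.
print(grid[1])
# [[0 1 2]
#  [0 1 2]]
```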

Parameters
  • dimensions (sequence of ints) – The shape of the grid.

  • dtype (dtype, optional) – Data type of the result.

  • sparse (boolean, optional) – Return a sparse representation of the grid instead of a dense representation. Default is False.
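The two optional parameters can be exercised as in the sketch below. The choice of jnp.int8 is purely illustrative. In practice the dtype should presumably be wide enough to hold the largest index, dimensions[i] - 1.

```python
import jax.numpy as jnp

# dtype sets the element type of the index values.
grid8 = jnp.indices((3, 4), dtype=jnp.int8)
print(grid8.dtype)  # int8
print(grid8.shape)  # (2, 3, 4)

# sparse=True returns one array per dimension instead of a single stacked array.
rows, cols = jnp.indices((3, 4), sparse=True)
print(rows.shape, cols.shape)  # (3, 1) (1, 4)
```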

Returns

grid –

If sparse is False:

Returns one array of grid indices, with grid.shape = (len(dimensions),) + tuple(dimensions).

If sparse is True:

Returns a tuple of arrays, where grid[i].shape = (1, ..., 1, dimensions[i], 1, ..., 1), with dimensions[i] in the i-th place.

Return type

one ndarray or tuple of ndarrays
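The shape rules for both return forms can be checked directly, as in the sketch below for a 3-D grid. It also shows that broadcasting each sparse array to the full shape recovers the matching slice of the dense grid. The final upper-triangular mask is only an illustrative use case, not part of the API. It relies on ordinary broadcasting of the (3, 1) and (1, 4) sparse arrays, which avoids materializing the dense grid.

```python
import jax.numpy as jnp

dims = (2, 3, 4)

# Dense form: one array of shape (len(dims),) + dims.
dense = jnp.indices(dims)
print(dense.shape)  # (3, 2, 3, 4)

# Sparse form: the i-th array has dims[i] on axis i and 1 on every other axis.
sparse_grid = jnp.indices(dims, sparse=True)
print([g.shape for g in sparse_grid])  # [(2, 1, 1), (1, 3, 1), (1, 1, 4)]

# Broadcasting each sparse array to the full shape reproduces the dense slice.
for i, g in enumerate(sparse_grid):
    assert (jnp.broadcast_to(g, dims) == dense[i]).all()

# Illustrative use: an upper-triangular mask built from a sparse 3 x 4 grid.
rows, cols = jnp.indices((3, 4), sparse=True)
upper = cols >= rows  # (3, 1) vs (1, 4) broadcasts to (3, 4)
print(upper.astype(jnp.int32))
# [[1 1 1 1]
#  [0 1 1 1]
#  [0 0 1 1]]
```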