jax.numpy.indices

jax.numpy.indices(dimensions, dtype=<class 'jax._src.numpy.lax_numpy.int32'>, sparse=False)

Return an array representing the indices of a grid.

LAX-backend implementation of numpy.indices(). The original NumPy docstring follows.

Compute an array where the subarrays contain index values 0, 1, … varying only along the corresponding axis.

Parameters
  • dimensions (sequence of ints) – The shape of the grid.

  • dtype (dtype, optional) – Data type of the result.

  • sparse (boolean, optional) – Return a sparse representation of the grid instead of a dense representation. Default is False.
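As a minimal sketch of how the three parameters are passed, the snippet below calls the function once per option. The variable names are illustrative only, and the exact default integer width is assumed to depend on the local JAX configuration.

```python
import jax.numpy as jnp

# dimensions: the grid shape, given as a sequence of ints.
default_grid = jnp.indices((2, 3))

# dtype: request a floating-point result instead of the integer default.
float_grid = jnp.indices((2, 3), dtype=jnp.float32)

# sparse: one broadcastable array per axis instead of a single stacked array.
sparse_grid = jnp.indices((2, 3), sparse=True)

print(default_grid.dtype)  # an integer dtype; width depends on configuration
print(float_grid.dtype)    # float32
print(len(sparse_grid))    # 2, one array per grid axis
```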

Returns

grid –

If sparse is False:

Returns one array of grid indices, grid.shape = (len(dimensions),) + tuple(dimensions).

If sparse is True:

Returns a tuple of arrays, where grid[i].shape = (1, ..., 1, dimensions[i], 1, ..., 1), with dimensions[i] in the i-th position.

Return type

one ndarray or tuple of ndarrays
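The two return forms can be compared side by side. The following sketch, which assumes a three-axis grid chosen purely for illustration, prints the shapes produced in the dense and sparse cases.

```python
import jax.numpy as jnp

dims = (2, 3, 4)  # illustrative grid shape

# Dense: a single array with one leading axis per grid dimension.
dense = jnp.indices(dims)
print(dense.shape)  # (3, 2, 3, 4)

# Sparse: a tuple of arrays, each of size 1 on every axis except its own.
sparse = jnp.indices(dims, sparse=True)
sparse_shapes = [g.shape for g in sparse]
print(sparse_shapes)  # [(2, 1, 1), (1, 3, 1), (1, 1, 4)]
```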

See also

mgrid(), ogrid(), meshgrid()

Notes

The output shape in the dense case is obtained by prepending the number of dimensions in front of the tuple of dimensions, i.e. if dimensions is a tuple (r0, ..., rN-1) of length N, the output shape is (N, r0, ..., rN-1).

The subarray grid[k] contains the N-D array of indices along the k-th axis. Explicitly:

grid[k, i0, i1, …, iN-1] = ik
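The identity above can be checked by brute force on a small grid. This is only a sketch: the grid shape is arbitrary, and the result is copied to a NumPy array so that the scalar lookups stay cheap.

```python
import itertools

import numpy as np
import jax.numpy as jnp

dims = (2, 3, 4)  # illustrative grid shape
grid = np.asarray(jnp.indices(dims))  # host copy for fast scalar indexing

# Collect every position where grid[k, i0, ..., iN-1] differs from ik.
mismatches = [
    (k, idx)
    for idx in itertools.product(*(range(d) for d in dims))
    for k in range(len(dims))
    if grid[(k, *idx)] != idx[k]
]
print(len(mismatches))  # 0
```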

Examples

>>> import numpy as np
>>> grid = np.indices((2, 3))
>>> grid.shape
(2, 2, 3)
>>> grid[0]        # row indices
array([[0, 0, 0],
       [1, 1, 1]])
>>> grid[1]        # column indices
array([[0, 1, 2],
       [0, 1, 2]])

The indices can be used as an index into an array.

>>> x = np.arange(20).reshape(5, 4)
>>> row, col = np.indices((2, 3))
>>> x[row, col]
array([[0, 1, 2],
       [4, 5, 6]])

Note that it would be more straightforward in the above example to extract the required elements directly with x[:2, :3].
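The same indexing pattern should carry over to JAX arrays. A minimal sketch, assuming jax.numpy is imported as jnp, that also confirms the equivalence with plain slicing:

```python
import jax.numpy as jnp

x = jnp.arange(20).reshape(5, 4)

# Unpack the dense grid into one index array per axis.
row, col = jnp.indices((2, 3))

# Integer-array indexing picks x[row[a, b], col[a, b]] for each position (a, b).
picked = x[row, col]

# For this rectangular block, slicing selects the same elements.
matches_slice = bool(jnp.array_equal(picked, x[:2, :3]))
print(picked.shape)    # (2, 3)
print(matches_slice)   # True
```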

If sparse is set to True, the grid is returned in a sparse representation.

>>> i, j = np.indices((2, 3), sparse=True)
>>> i.shape
(2, 1)
>>> j.shape
(1, 3)
>>> i        # row indices
array([[0],
       [1]])
>>> j        # column indices
array([[0, 1, 2]])
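The sparse arrays broadcast against each other, so they can often stand in for the dense grid in arithmetic. The sketch below, written against jax.numpy with illustrative variable names, combines them directly and then expands them to recover the dense form.

```python
import jax.numpy as jnp

i, j = jnp.indices((2, 3), sparse=True)

# Broadcasting (2, 1) against (1, 3) yields a full (2, 3) result.
table = 10 * i + j
print(table.shape)  # (2, 3)

# Expanding the sparse arrays explicitly reproduces the dense grid.
rows, cols = jnp.broadcast_arrays(i, j)
dense = jnp.indices((2, 3))
same_rows = bool(jnp.array_equal(rows, dense[0]))
same_cols = bool(jnp.array_equal(cols, dense[1]))
print(same_rows, same_cols)  # True True
```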