jax.numpy.indices
- jax.numpy.indices(dimensions, dtype=None, sparse=False)
Generate arrays of grid indices.
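JAX implementation of numpy.indices().
- Parameters:
  - dimensions: the shape of the grid, as a sequence of integers.
  - dtype: the dtype of the returned index arrays (defaults to the default integer type).
  - sparse: if True, return sparse indices; if False (the default), return dense indices.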
- Returns:
  An array of shape (len(dimensions), *dimensions) if sparse is False, or a tuple of arrays of the same length as dimensions if sparse is True.
- Return type:
  Array | tuple[Array, ...]
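A short sketch of the two return forms described above, using a 3-D shape chosen only for illustration. It checks the shape of the dense output, the per-axis shapes of the sparse tuple, and confirms that broadcasting the sparse arrays against each other (here via jnp.broadcast_arrays) reproduces the dense grid. It also shows the dtype argument; the specific shape and dtype used are assumptions for the example, not defaults.

```python
import jax.numpy as jnp

shape = (2, 3, 4)  # illustrative grid shape

# Dense: one stacked index array per axis, shape (len(shape), *shape).
dense = jnp.indices(shape)
print(dense.shape)  # (3, 2, 3, 4)

# Sparse: a tuple with one array per axis; array i has length shape[i]
# along axis i and length 1 along every other axis.
sparse = jnp.indices(shape, sparse=True)
print([a.shape for a in sparse])  # [(2, 1, 1), (1, 3, 1), (1, 1, 4)]

# Broadcasting the sparse arrays together recovers the dense grid.
recovered = jnp.stack(jnp.broadcast_arrays(*sparse))
print(jnp.array_equal(recovered, dense))  # True

# The dtype argument sets the element type of the index arrays.
print(jnp.indices(shape, dtype=jnp.int8).dtype)  # int8
```

The sparse form stores only sum(dimensions) elements instead of len(dimensions) * prod(dimensions), which is why it is usually preferred when the indices are only consumed through broadcasting.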
See also
- jax.numpy.meshgrid(): generate a grid from arbitrary input arrays.
- jax.numpy.mgrid: generate dense indices using a slicing syntax.
- jax.numpy.ogrid: generate sparse indices using a slicing syntax.
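As a rough illustration of how the related functions line up, the sketch below builds the same 2x3 grid three ways in dense form and two ways in sparse form. Note that jnp.meshgrid needs indexing="ij" here, because its default "xy" indexing swaps the first two axes; the grid size is an arbitrary choice for the example.

```python
import jax.numpy as jnp

# Dense indices for a 2x3 grid, built three equivalent ways.
a = jnp.indices((2, 3))
b = jnp.mgrid[0:2, 0:3]
c = jnp.stack(jnp.meshgrid(jnp.arange(2), jnp.arange(3), indexing="ij"))
print(jnp.array_equal(a, b), jnp.array_equal(a, c))  # True True

# Sparse indices: indices(..., sparse=True) versus ogrid.
# Both yield arrays of shapes (2, 1) and (1, 3).
sa = jnp.indices((2, 3), sparse=True)
so = jnp.ogrid[0:2, 0:3]
print(all(jnp.array_equal(x, y) for x, y in zip(sa, so)))  # True
```

jnp.indices only takes a shape and always starts at 0 with step 1, whereas mgrid/ogrid accept arbitrary slices and meshgrid accepts arbitrary coordinate arrays.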
Examples
>>> jnp.indices((2, 3))
Array([[[0, 0, 0],
        [1, 1, 1]],

       [[0, 1, 2],
        [0, 1, 2]]], dtype=int32)
>>> jnp.indices((2, 3), sparse=True)
(Array([[0],
       [1]], dtype=int32), Array([[0, 1, 2]], dtype=int32))
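A minimal sketch of a typical use of sparse indices: building a lower-triangular boolean mask by broadcasting a comparison, inside jax.jit. The function name lower_triangular_mask is hypothetical, and the sketch assumes a 2-D input. It relies on x.shape being a static tuple of Python ints under jit, so it can be passed as dimensions.

```python
import jax
import jax.numpy as jnp

@jax.jit
def lower_triangular_mask(x):
    # Hypothetical helper; assumes x is 2-D. x.shape is static under jit,
    # so it can be used as `dimensions`.
    # Sparse indices: i has shape (n, 1) and j has shape (1, m).
    i, j = jnp.indices(x.shape, sparse=True)
    # Broadcasting the comparison yields a full (n, m) boolean mask that is
    # True on and below the main diagonal.
    return i >= j

mask = lower_triangular_mask(jnp.zeros((4, 4)))
print(mask.shape)  # (4, 4)
print(jnp.array_equal(mask, jnp.tril(jnp.ones((4, 4), dtype=bool))))  # True
```

Using sparse=True here avoids materializing two full (n, m) index arrays; only the final mask has the full shape.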