jax.numpy.ogrid

jax.numpy.ogrid = <jax._src.numpy.index_tricks._Ogrid object>

Return open multi-dimensional “meshgrid”.

LAX-backend implementation of numpy.ogrid. This is a convenience wrapper for functionality provided by jax.numpy.meshgrid() with sparse=True.
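The equivalence with jax.numpy.meshgrid() can be checked directly. The sketch below assumes indexing="ij", which matches the axis order in the multi-slice example (the first slice varies along axis 0). It is an illustration of the relationship, not the library's internal code.

```python
import jax.numpy as jnp

# Open grid built with index syntax: one array per slice.
rows, cols = jnp.ogrid[:2, :3]

# The same grid built explicitly from 1-D ranges. indexing="ij" (an assumption
# chosen to match ogrid's axis order) keeps the first range along axis 0.
rows_m, cols_m = jnp.meshgrid(jnp.arange(2), jnp.arange(3), indexing="ij", sparse=True)

print(rows.shape, cols.shape)  # (2, 1) (1, 3)
print(bool((rows == rows_m).all()), bool((cols == cols_m).all()))  # True True
```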

See also

jnp.mgrid: dense version of jnp.ogrid
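The two differ only in how much they store: broadcasting the arrays returned by jnp.ogrid against each other should reproduce the stacked output of jnp.mgrid. A minimal sketch comparing the two:

```python
import jax.numpy as jnp

# Sparse (open) grid: shapes (2, 1) and (1, 3), 5 stored elements in total.
open_rows, open_cols = jnp.ogrid[:2, :3]

# Dense grid: a single stacked array of shape (2, 2, 3), 12 stored elements.
dense = jnp.mgrid[:2, :3]

# Broadcasting the open grid to the full (2, 3) shape recovers each dense slice.
full_rows, full_cols = jnp.broadcast_arrays(open_rows, open_cols)
print(dense.shape)  # (2, 2, 3)
print(bool((dense[0] == full_rows).all()), bool((dense[1] == full_cols).all()))  # True True
```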

Examples

Pass [start:stop:step] to generate values similar to jax.numpy.arange():

>>> jnp.ogrid[0:4:1]
Array([0, 1, 2, 3], dtype=int32)
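With a real step, the slice follows jax.numpy.arange() semantics, so the stop value is excluded. A minimal check with a non-unit step:

```python
import jax.numpy as jnp

# Real step: equivalent to arange(start, stop, step); 10 is not included.
a = jnp.ogrid[1:10:3]
b = jnp.arange(1, 10, 3)

print(a.tolist())  # [1, 4, 7]
print(bool((a == b).all()))  # True
```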

Passing an imaginary step such as 4j generates that many evenly spaced points, with the stop value included, similar to jax.numpy.linspace():

>>> jnp.ogrid[0:1:4j]
Array([0.        , 0.33333334, 0.6666667 , 1.        ], dtype=float32)
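The magnitude of the imaginary step sets the number of points, so the result should agree with jax.numpy.linspace() over the same endpoints. The comparison below uses jnp.allclose rather than exact equality in case the two code paths round differently in float32:

```python
import jax.numpy as jnp

# Imaginary step 5j: request 5 points between 0 and 2, endpoint included.
a = jnp.ogrid[0:2:5j]
b = jnp.linspace(0, 2, 5)

print(a.shape)  # (5,)
print(bool(jnp.allclose(a, b)))  # True
```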

Multiple slices produce a list of sparse index arrays, one per slice, each extending along its own axis:

>>> jnp.ogrid[:2, :3]
[Array([[0],
        [1]], dtype=int32),
 Array([[0, 1, 2]], dtype=int32)]
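Because each returned array has size 1 along every axis except its own, an open grid is usually consumed through broadcasting, and the full-size array only appears in the result of the arithmetic. A sketch building a row-major flat-index table (the formula i * 4 + j is only an illustration):

```python
import jax.numpy as jnp

# Open index grids: i has shape (3, 1), j has shape (1, 4).
i, j = jnp.ogrid[:3, :4]

# Arithmetic broadcasts the two pieces to shape (3, 4) only in the result;
# i * 4 + j is the row-major flat index of each (i, j) position.
flat = i * 4 + j

print(flat.shape)  # (3, 4)
print(flat.tolist())  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
```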