jax.numpy.ogrid
jax.numpy.ogrid = <jax._src.numpy.index_tricks._Ogrid object>
Return open multi-dimensional “meshgrid”.
LAX-backend implementation of numpy.ogrid. This is a convenience wrapper for functionality provided by jax.numpy.meshgrid() with sparse=True.
See also
jnp.mgrid: dense version of jnp.ogrid
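The meshgrid relationship can be checked directly. The sketch below builds the same open grid with jax.numpy.meshgrid(), passing indexing='ij' so the axis order matches ogrid, and contrasts it with the dense jnp.mgrid. The variable names are illustrative only.

```python
import jax.numpy as jnp

# Open (sparse) grid: one array per axis, each with size 1 along the other axes.
rows, cols = jnp.ogrid[:2, :3]

# The same grid built with meshgrid. indexing='ij' is needed here because
# meshgrid's default 'xy' indexing swaps the first two axes.
rows_mg, cols_mg = jnp.meshgrid(jnp.arange(2), jnp.arange(3),
                                indexing='ij', sparse=True)

# The dense counterpart materializes every grid point, stacked on axis 0.
dense = jnp.mgrid[:2, :3]

print(rows.shape, cols.shape)  # (2, 1) (1, 3)
print(dense.shape)             # (2, 2, 3)
```

The open grid stores 2 + 3 = 5 values, while the dense grid stores 2 * 2 * 3 = 12; the gap grows quickly as axes get longer or more numerous.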
Examples
Pass [start:stop:step] to generate values similar to jax.numpy.arange():
>>> jnp.ogrid[0:4:1]
Array([0, 1, 2, 3], dtype=int32)
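A minimal check of the arange() analogy: with a real step, the slice behaves like arange(start, stop, step), so the stop value is excluded and fractional steps are accepted. This compares outputs only and is not meant as a description of jnp.ogrid's internals.

```python
import jax.numpy as jnp

# Integer step: same values as arange, stop (4) excluded.
a = jnp.ogrid[0:4:1]
b = jnp.arange(0, 4, 1)

# Fractional real step: still arange-like, so 1.0 is not included.
c = jnp.ogrid[0:1:0.25]
d = jnp.arange(0, 1, 0.25)

print(c.tolist())  # [0.0, 0.25, 0.5, 0.75]
```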
Passing an imaginary step generates values similar to jax.numpy.linspace(): the magnitude of the step sets the number of points, and the stop value is included:
>>> jnp.ogrid[0:1:4j]
Array([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32)
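A short sketch of the linspace() analogy, following the numpy.ogrid convention that an imaginary step Nj means "N points, endpoint included". Five points are used here so the values are exact in float32.

```python
import jax.numpy as jnp

# 5j -> 5 evenly spaced points from 0 to 1, with 1 included.
e = jnp.ogrid[0:1:5j]
f = jnp.linspace(0, 1, 5)

print(e.tolist())  # [0.0, 0.25, 0.5, 0.75, 1.0]
```

Compared with the real step 0.25, which stops before 1.0, the imaginary step fixes the point count rather than the spacing.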
Multiple slices create a sparse grid of indices, with one array per slice that has size 1 along every other axis:
>>> jnp.ogrid[:2, :3]
[Array([[0],
       [1]], dtype=int32), Array([[0, 1, 2]], dtype=int32)]
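Because the arrays in an open grid have size 1 along the other axes, ordinary arithmetic broadcasts them to the full grid. A minimal sketch, with illustrative variable names, computing a row-major flat index and a squared distance from the origin for every cell of a 2 x 3 grid:

```python
import jax.numpy as jnp

# i has shape (2, 1) and j has shape (1, 3).
i, j = jnp.ogrid[:2, :3]

# Broadcasting produces (2, 3) results without building a dense index grid first.
flat_index = i * 3 + j   # row-major position of each cell
sq_dist = i**2 + j**2    # squared distance of each cell from (0, 0)

print(flat_index.tolist())  # [[0, 1, 2], [3, 4, 5]]
print(sq_dist.tolist())     # [[0, 1, 4], [1, 2, 5]]
```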