jax.numpy.mgrid

jax.numpy.mgrid = <jax._src.numpy.lax_numpy._Mgrid object>

Return a dense multi-dimensional “meshgrid”.

LAX-backend implementation of numpy.mgrid. This is a convenience wrapper around jax.numpy.meshgrid() called with indexing='ij' and sparse=False.
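
As a rough illustration of the relationship to jax.numpy.meshgrid(), the sketch below builds the same grid both ways. The variable names are illustrative only, and indexing="ij" is assumed here so that the axis order matches mgrid:

```python
import jax.numpy as jnp

# Dense grid from mgrid: one leading entry per sliced dimension.
grid = jnp.mgrid[0:2, 0:3]          # shape (2, 2, 3)

# The same coordinates built explicitly with meshgrid.
# indexing="ij" is assumed so that the first axis varies along rows.
rows, cols = jnp.meshgrid(
    jnp.arange(2), jnp.arange(3), indexing="ij", sparse=False
)
stacked = jnp.stack([rows, cols])   # stack along a new leading axis

# True when both constructions agree element for element.
same = bool(jnp.array_equal(grid, stacked))
```

The mgrid form avoids the explicit arange and stack calls when the ranges are known up front.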

See also

jnp.ogrid: open/sparse version of jnp.mgrid

Examples

Pass [start:stop:step] to generate values similar to jax.numpy.arange():

>>> jnp.mgrid[0:4:1]
DeviceArray([0, 1, 2, 3], dtype=int32)
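
A small sketch of the correspondence with jax.numpy.arange(), assuming a single slice with integer bounds. The step may be omitted, in which case it defaults to 1:

```python
import jax.numpy as jnp

# A single slice with an explicit step mirrors arange(start, stop, step).
stepped = jnp.mgrid[1:10:3]
stepped_ref = jnp.arange(1, 10, 3)

# Omitting the step behaves like a step of 1.
unit = jnp.mgrid[2:5]
unit_ref = jnp.arange(2, 5)

matches = bool(jnp.array_equal(stepped, stepped_ref)) and bool(
    jnp.array_equal(unit, unit_ref)
)
```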

Passing an imaginary step generates values similar to jax.numpy.linspace(); the magnitude of the step is the number of points, and the stop value is included:

>>> jnp.mgrid[0:1:4j]
DeviceArray([0.        , 0.33333334, 0.6666667 , 1.        ], dtype=float32)
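
The following sketch compares the imaginary-step form with an explicit jax.numpy.linspace() call. It is only a sanity check under the assumption that both use the default floating-point dtype:

```python
import jax.numpy as jnp

# The imaginary step 4j requests 4 points from 0 to 1, endpoint included.
points = jnp.mgrid[0:1:4j]

# Equivalent explicit call: linspace(start, stop, num).
reference = jnp.linspace(0, 1, 4)

# allclose is used rather than exact equality to tolerate rounding.
close = bool(jnp.allclose(points, reference))
```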

Multiple slices can be used to create broadcasted grids of indices; the result stacks one coordinate array per slice along a new leading axis:

>>> jnp.mgrid[:2, :3]
DeviceArray([[[0, 0, 0],
              [1, 1, 1]],
             [[0, 1, 2],
              [0, 1, 2]]], dtype=int32)
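
One possible way to use such a grid is to unpack it into per-axis coordinate arrays and evaluate an elementwise expression at every grid point. The sketch below also shows the jnp.ogrid counterpart mentioned under "See also"; the expression 10 * rows + cols is an arbitrary example, not part of the API:

```python
import jax.numpy as jnp

# Unpack the leading axis into one coordinate array per dimension.
rows, cols = jnp.mgrid[:2, :3]        # each has shape (2, 3)

# Evaluate an elementwise expression at every grid point.
values = 10 * rows + cols             # [[0, 1, 2], [10, 11, 12]]

# ogrid returns open (sparse) arrays that broadcast to the same grid.
r, c = jnp.ogrid[:2, :3]              # shapes (2, 1) and (1, 3)
sparse_values = 10 * r + c

agree = bool(jnp.array_equal(values, sparse_values))
```

The dense form materializes every coordinate, while the open form relies on broadcasting and stores less, which can matter for large grids.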