jax.numpy.meshgrid
- jax.numpy.meshgrid(*xi, copy=True, sparse=False, indexing='xy')
Construct N-dimensional grid arrays from N 1-dimensional vectors.
JAX implementation of numpy.meshgrid().
- Parameters:
xi (ArrayLike) – N one-dimensional arrays to convert to a grid.
copy (bool) – whether to copy the input arrays. JAX supports only copy=True, though under JIT compilation the compiler may opt to avoid copies.
sparse (bool) – if False (default), each returned array has the full grid shape [len(x1), len(x2), ..., len(xN)]. If True, the i-th returned array has shape [1, 1, ..., len(xi), ..., 1, 1], with len(xi) along axis i, so the returned arrays broadcast against each other to the full grid shape. With indexing='xy' and two or more inputs, the first two axes of each of these shapes are swapped.
indexing (str) – options are 'xy' for Cartesian indexing (default) or 'ij' for matrix indexing.
- Returns:
A length-N list of grid arrays.
- Return type:
list[Array]
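As a quick sketch of the shape rules for sparse and indexing, the snippet below builds grids from three inputs of distinct lengths (2, 3 and 4, chosen arbitrarily so every axis is identifiable) and records the shape of each output. The helper name shapes is illustrative only. Note the swap of the first two axes under indexing='xy'.

```python
import jax.numpy as jnp

# Three 1-D inputs with distinct lengths so every axis is identifiable.
a = jnp.arange(2)
b = jnp.arange(3)
c = jnp.arange(4)

def shapes(grids):
    # Hypothetical helper: collect the shape of each returned grid array.
    return [g.shape for g in grids]

# Dense grids: every output has the full grid shape.
dense_ij = shapes(jnp.meshgrid(a, b, c, indexing='ij'))
dense_xy = shapes(jnp.meshgrid(a, b, c, indexing='xy'))

# Sparse grids: each output has size 1 along every axis except its own.
sparse_ij = shapes(jnp.meshgrid(a, b, c, sparse=True, indexing='ij'))
sparse_xy = shapes(jnp.meshgrid(a, b, c, sparse=True, indexing='xy'))

print(dense_ij)   # [(2, 3, 4), (2, 3, 4), (2, 3, 4)]
print(dense_xy)   # [(3, 2, 4), (3, 2, 4), (3, 2, 4)]
print(sparse_ij)  # [(2, 1, 1), (1, 3, 1), (1, 1, 4)]
print(sparse_xy)  # [(1, 2, 1), (3, 1, 1), (1, 1, 4)]
```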
See also
jax.numpy.mgrid: create a meshgrid using indexing syntax.
jax.numpy.ogrid: create an open meshgrid using indexing syntax.
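The relationship to the indexing-syntax variants can be checked with a small sketch, assuming integer slices: jnp.mgrid[0:2, 0:3] stacks the same arrays that meshgrid produces with indexing='ij' over jnp.arange(2) and jnp.arange(3), and jnp.ogrid returns the matching sparse (open) grids.

```python
import jax.numpy as jnp

rows = jnp.arange(2)
cols = jnp.arange(3)

# Dense grids: mgrid returns one stacked array whose leading axis
# selects the grid, giving shape (2, 2, 3) here.
stacked = jnp.mgrid[0:2, 0:3]
ii, jj = jnp.meshgrid(rows, cols, indexing='ij')
dense_match = bool(jnp.array_equal(stacked[0], ii)) and bool(jnp.array_equal(stacked[1], jj))

# Open grids: ogrid returns a list of arrays shaped like sparse=True output.
oi, oj = jnp.ogrid[0:2, 0:3]
si, sj = jnp.meshgrid(rows, cols, indexing='ij', sparse=True)
open_match = (
    (oi.shape, oj.shape) == (si.shape, sj.shape)
    and bool(jnp.array_equal(oi, si))
    and bool(jnp.array_equal(oj, sj))
)

print(stacked.shape, dense_match, open_match)  # (2, 2, 3) True True
```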
Examples
For the following examples, we’ll use these 1D arrays as inputs:
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2])
>>> y = jnp.array([10, 20, 30])
2D Cartesian mesh grid:
>>> x_grid, y_grid = jnp.meshgrid(x, y)
>>> print(x_grid)
[[1 2]
 [1 2]
 [1 2]]
>>> print(y_grid)
[[10 10]
 [20 20]
 [30 30]]
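A common use of a dense Cartesian grid is evaluating a two-variable function at every (x, y) pair. The sketch below uses an arbitrary example function f(x, y) = x * y; under the default indexing='xy', element [i, j] of the result corresponds to x[j] and y[i].

```python
import jax.numpy as jnp

x = jnp.array([1, 2])
y = jnp.array([10, 20, 30])

x_grid, y_grid = jnp.meshgrid(x, y)  # default indexing='xy', shape (3, 2)

# Evaluate the example function f(x, y) = x * y at every grid point.
z = x_grid * y_grid

# Under 'xy' indexing, rows follow y and columns follow x.
print(z)
# [[10 20]
#  [20 40]
#  [30 60]]
print(int(z[2, 1]) == int(x[1] * y[2]))  # True
```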
2D sparse Cartesian mesh grid:
>>> x_grid, y_grid = jnp.meshgrid(x, y, sparse=True)
>>> print(x_grid)
[[1 2]]
>>> print(y_grid)
[[10]
 [20]
 [30]]
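Sparse outputs are meant to be combined through broadcasting: arithmetic on the (1, 2) and (3, 1) arrays expands to the full (3, 2) grid without materializing dense inputs first. A quick check, reusing the illustrative function f(x, y) = x * y:

```python
import jax.numpy as jnp

x = jnp.array([1, 2])
y = jnp.array([10, 20, 30])

# Sparse grids: shapes (1, 2) and (3, 1).
xs, ys = jnp.meshgrid(x, y, sparse=True)
# Dense grids: both of shape (3, 2).
xd, yd = jnp.meshgrid(x, y)

# Broadcasting the sparse grids reproduces the dense evaluation.
z_sparse = xs * ys
z_dense = xd * yd
print(z_sparse.shape, bool(jnp.array_equal(z_sparse, z_dense)))  # (3, 2) True
```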
2D matrix-index mesh grid:
>>> x_grid, y_grid = jnp.meshgrid(x, y, indexing='ij')
>>> print(x_grid)
[[1 1 1]
 [2 2 2]]
>>> print(y_grid)
[[10 20 30]
 [10 20 30]]
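For two inputs, the two indexing modes produce the same grids up to a transpose: each 'ij' output equals the corresponding 'xy' output with its two axes swapped. A short check with the example inputs:

```python
import jax.numpy as jnp

x = jnp.array([1, 2])
y = jnp.array([10, 20, 30])

x_xy, y_xy = jnp.meshgrid(x, y, indexing='xy')  # shape (3, 2)
x_ij, y_ij = jnp.meshgrid(x, y, indexing='ij')  # shape (2, 3)

# In 2D, switching indexing modes transposes each grid.
same_x = bool(jnp.array_equal(x_ij, x_xy.T))
same_y = bool(jnp.array_equal(y_ij, y_xy.T))
print(x_ij.shape, x_xy.shape, same_x, same_y)  # (2, 3) (3, 2) True True
```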