jax.numpy.meshgrid


jax.numpy.meshgrid(*xi, copy=True, sparse=False, indexing='xy')

Construct N-dimensional grid arrays from N 1-dimensional vectors.

JAX implementation of numpy.meshgrid().

Parameters:
  • xi (ArrayLike) – N arrays to convert to a grid.

  • copy (bool) – whether to copy the input arrays. JAX supports only copy=True, though under JIT compilation the compiler may opt to avoid copies.

  • sparse (bool) – if False (default), then each returned array will be of shape [len(x1), len(x2), ..., len(xN)]. If True, then returned arrays will be of shape [1, 1, ..., len(xi), ..., 1, 1].

  • indexing (str) – options are 'xy' for cartesian indexing (default) or 'ij' for matrix indexing.

Returns:

A length-N list of grid arrays.

Return type:

list[Array]
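
The shape rules described for sparse and indexing can be checked with three inputs. The following is a minimal sketch; the input names a, b, and c are illustrative and not part of the original documentation. Note that under the default 'xy' indexing, the first two dimensions of the output are swapped relative to 'ij'.

```python
import jax.numpy as jnp

# Three 1D inputs of lengths 2, 3, and 4 (illustrative values).
a = jnp.arange(2)
b = jnp.arange(3)
c = jnp.arange(4)

# Dense grids with matrix indexing: every grid has the full shape.
dense_ij = jnp.meshgrid(a, b, c, indexing='ij')
# Sparse grids with matrix indexing: each grid is non-trivial along one axis only.
sparse_ij = jnp.meshgrid(a, b, c, indexing='ij', sparse=True)
# Dense grids with the default cartesian indexing: first two axes are swapped.
dense_xy = jnp.meshgrid(a, b, c)

print([g.shape for g in dense_ij])   # [(2, 3, 4), (2, 3, 4), (2, 3, 4)]
print([g.shape for g in sparse_ij])  # [(2, 1, 1), (1, 3, 1), (1, 1, 4)]
print([g.shape for g in dense_xy])   # [(3, 2, 4), (3, 2, 4), (3, 2, 4)]
```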

Examples

For the following examples, we’ll use these 1D arrays as inputs:

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2])
>>> y = jnp.array([10, 20, 30])

2D cartesian mesh grid:

>>> x_grid, y_grid = jnp.meshgrid(x, y)
>>> print(x_grid)
[[1 2]
 [1 2]
 [1 2]]
>>> print(y_grid)
[[10 10]
 [20 20]
 [30 30]]

2D sparse cartesian mesh grid:

>>> x_grid, y_grid = jnp.meshgrid(x, y, sparse=True)
>>> print(x_grid)
[[1 2]]
>>> print(y_grid)
[[10]
 [20]
 [30]]
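
Sparse grids are often enough in practice because broadcasting expands them to the full grid shape during arithmetic. The sketch below illustrates this by comparing a sum computed on dense and sparse grids; the variable names are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([1, 2])
y = jnp.array([10, 20, 30])

# Dense grids: both arrays have shape (3, 2).
xd, yd = jnp.meshgrid(x, y)
# Sparse grids: shapes (1, 2) and (3, 1).
xs, ys = jnp.meshgrid(x, y, sparse=True)

# Broadcasting (1, 2) against (3, 1) yields the full (3, 2) result.
dense_sum = xd + yd
sparse_sum = xs + ys

print(sparse_sum.shape)                                # (3, 2)
print(bool(jnp.array_equal(dense_sum, sparse_sum)))    # True
```

The sparse form stores len(x) + len(y) elements instead of 2 * len(x) * len(y), which can matter for large grids.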

2D matrix-index mesh grid:

>>> x_grid, y_grid = jnp.meshgrid(x, y, indexing='ij')
>>> print(x_grid)
[[1 1 1]
 [2 2 2]]
>>> print(y_grid)
[[10 20 30]
 [10 20 30]]
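
For two inputs, the 'xy' and 'ij' results are transposes of one another, as the outputs above suggest. A short check of that relationship, using illustrative variable names:

```python
import jax.numpy as jnp

x = jnp.array([1, 2])
y = jnp.array([10, 20, 30])

# Cartesian indexing: grids have shape (len(y), len(x)) = (3, 2).
x_xy, y_xy = jnp.meshgrid(x, y, indexing='xy')
# Matrix indexing: grids have shape (len(x), len(y)) = (2, 3).
x_ij, y_ij = jnp.meshgrid(x, y, indexing='ij')

# Transposing the 'ij' grids recovers the 'xy' grids in the 2D case.
same_x = bool(jnp.array_equal(x_xy, x_ij.T))
same_y = bool(jnp.array_equal(y_xy, y_ij.T))

print(x_xy.shape, x_ij.shape)  # (3, 2) (2, 3)
print(same_x, same_y)          # True True
```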