jax.numpy.ix_

jax.numpy.ix_(*args)

Return a multi-dimensional grid (open mesh) from N one-dimensional sequences.

JAX implementation of numpy.ix_().
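The open mesh amounts to a reshape. The i-th input is given length `len(args[i])` along axis i and size 1 along every other axis. The sketch below reproduces that behaviour for non-empty, one-dimensional inputs. `ix_sketch` is a hypothetical helper written for illustration only. It is not part of JAX, is not claimed to mirror JAX's internal implementation, and omits any special handling JAX may apply to edge cases such as empty inputs.

```python
import jax.numpy as jnp


def ix_sketch(*args):
    """Hypothetical re-implementation of the open-mesh reshape (illustration only)."""
    n = len(args)
    out = []
    for i, a in enumerate(args):
        a = jnp.asarray(a)
        if a.ndim != 1:
            raise ValueError(f"argument {i} must be 1-D, got shape {a.shape}")
        # Size 1 on every axis except axis i, which holds the input's length.
        shape = [1] * n
        shape[i] = a.shape[0]
        out.append(a.reshape(shape))
    return tuple(out)


rows = jnp.array([0, 2])
cols = jnp.array([1, 3])
mine = ix_sketch(rows, cols)
ref = jnp.ix_(rows, cols)
print([m.shape for m in mine])  # [(2, 1), (1, 2)]
print(all(bool((m == r).all()) for m, r in zip(mine, ref)))  # True
```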

Parameters:

*args (jax.typing.ArrayLike) – N one-dimensional arrays
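Every argument must be one-dimensional. If the indices start out as a Python list or as a multi-dimensional array, one option is to convert or flatten them first with `jnp.asarray` and `.ravel()`. The variable names and values below are illustrative.

```python
import jax.numpy as jnp

# Index sources that are not yet 1-D JAX arrays (illustrative values).
row_list = [0, 2]                    # plain Python list
col_block = jnp.array([[1], [3]])    # shape (2, 1), not 1-D

rows = jnp.asarray(row_list)         # shape (2,)
cols = col_block.ravel()             # flattened to shape (2,)

mesh = jnp.ix_(rows, cols)
print([m.shape for m in mesh])  # [(2, 1), (1, 2)]
```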

Returns:

Tuple of JAX arrays forming an open mesh, each with N dimensions.

Return type:

tuple[Array, …]
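To illustrate "each with N dimensions", the following sketch passes three inputs and checks the output shapes. Broadcasting the open mesh together produces the full grid. For one-dimensional inputs the result also matches `jnp.meshgrid` called with `sparse=True` and `indexing="ij"`. The input values are arbitrary.

```python
import jax.numpy as jnp

# Three 1-D index arrays of different lengths.
a = jnp.array([0, 1])
b = jnp.array([10, 20, 30])
c = jnp.array([100, 200, 300, 400])

mesh = jnp.ix_(a, b, c)
# Each output has N = 3 dimensions; only axis i is non-trivial for the i-th input.
print([m.shape for m in mesh])  # [(2, 1, 1), (1, 3, 1), (1, 1, 4)]

# Broadcasting the open mesh together yields the full (2, 3, 4) grid.
total = mesh[0] + mesh[1] + mesh[2]
print(total.shape)             # (2, 3, 4)
print(int(total[1, 2, 3]))     # 1 + 30 + 400 = 431

# For 1-D inputs the result matches a sparse, matrix-indexed meshgrid.
sparse = jnp.meshgrid(a, b, c, sparse=True, indexing="ij")
print(all(m.shape == s.shape and bool((m == s).all())
          for m, s in zip(mesh, sparse)))  # True
```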

Example

>>> import jax.numpy as jnp
>>> rows = jnp.array([0, 2])
>>> cols = jnp.array([1, 3])
>>> open_mesh = jnp.ix_(rows, cols)
>>> open_mesh
(Array([[0],
       [2]], dtype=int32), Array([[1, 3]], dtype=int32))
>>> [grid.shape for grid in open_mesh]
[(2, 1), (1, 2)]
>>> x = jnp.array([[10, 20, 30, 40],
...                [50, 60, 70, 80],
...                [90, 100, 110, 120],
...                [130, 140, 150, 160]])
>>> x[open_mesh]
Array([[ 20,  40],
       [100, 120]], dtype=int32)
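The same index tuple also works for functional updates through `.at[...]`, because JAX arrays are immutable. Since the output shapes of `jnp.ix_` depend only on the input shapes, it can also be used inside `jax.jit`. `take_block` below is a hypothetical helper name used only for this sketch.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[10, 20, 30, 40],
               [50, 60, 70, 80],
               [90, 100, 110, 120],
               [130, 140, 150, 160]])
rows = jnp.array([0, 2])
cols = jnp.array([1, 3])
mesh = jnp.ix_(rows, cols)

# .at[...].set(...) returns an updated copy; x itself is unchanged.
y = x.at[mesh].set(0)
print(y.tolist())
# [[10, 0, 30, 0], [50, 60, 70, 80], [90, 0, 110, 0], [130, 140, 150, 160]]


# Hypothetical helper: select the rows x cols block under jit.
@jax.jit
def take_block(arr, r, c):
    return arr[jnp.ix_(r, c)]


print(take_block(x, rows, cols).tolist())  # [[20, 40], [100, 120]]
```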