jax.lax.index_in_dim

jax.lax.index_in_dim(operand, index, axis=0, keepdims=True)

Convenience wrapper around lax.slice() to perform integer indexing along a single axis.

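This is effectively equivalent to operand[..., index] with the indexing applied on the specified axis. With keepdims=True, the indexed axis is kept with size 1 instead of being removed, so the result has the same rank as operand.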
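To make the equivalence concrete, the sketch below compares lax.index_in_dim on a 2-D array with ordinary NumPy-style indexing and with an explicit width-1 lax.slice. The helper slice_equivalent is hypothetical: it only illustrates how such a slice can be written by hand and is not claimed to be the library's implementation. It also assumes a non-negative, in-bounds index.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(3, 4)

# keepdims=False removes the indexed axis, matching x[:, 2].
col = lax.index_in_dim(x, 2, axis=1, keepdims=False)

# keepdims=True (the default) keeps the indexed axis with size 1, matching x[:, 2:3].
col_kept = lax.index_in_dim(x, 2, axis=1)

def slice_equivalent(operand, index, axis):
    # Hypothetical helper: a width-1 lax.slice along `axis`.
    # Assumes 0 <= index < operand.shape[axis].
    start = [0] * operand.ndim
    limit = list(operand.shape)
    start[axis] = index
    limit[axis] = index + 1
    return lax.slice(operand, start, limit)

manual = slice_equivalent(x, 2, axis=1)  # same values and shape as col_kept
```

Both results hold the column [2, 6, 10]; they differ only in whether the size-1 axis survives.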

Parameters:
  • operand (Array | np.ndarray) – an array to index.

  • index (int) – the integer index to select along axis

  • axis (int) – the axis along which to apply the index (defaults to 0)

  • keepdims (bool) – whether the output array should preserve the rank of the input by keeping the indexed axis with size 1 (defaults to True)
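The combined effect of axis and keepdims is easiest to see from output shapes. This sketch indexes position 1 along each axis of a 3-D array in turn and records the resulting shapes for both keepdims settings. The array sizes were chosen only so that each axis is easy to tell apart.

```python
import jax.numpy as jnp
from jax import lax

# A 3-D array with a different size on each axis: (2, 3, 4).
x = jnp.arange(24).reshape(2, 3, 4)

shapes = {}
for axis in range(x.ndim):
    # Index position 1, which is in bounds for every axis of x.
    kept = lax.index_in_dim(x, 1, axis=axis)                     # indexed axis kept with size 1
    dropped = lax.index_in_dim(x, 1, axis=axis, keepdims=False)  # indexed axis removed
    shapes[axis] = (kept.shape, dropped.shape)

for axis, (kept_shape, dropped_shape) in shapes.items():
    print(f"axis={axis}: keepdims=True -> {kept_shape}, keepdims=False -> {dropped_shape}")
```

With keepdims=True the rank stays at 3 and only the indexed axis shrinks to 1; with keepdims=False that axis disappears and the rank drops to 2.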

Return type:

Array

Returns:

The subarray at the specified index.
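Since index is typed as a plain int, the sketch below assumes it must be a concrete Python integer when the function is traced. It wraps the call in jax.jit and marks index as static with static_argnums, so each distinct index value gets its own compiled version. The function name take_row is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

def take_row(x, index):
    # Assumption: `index` must be a concrete Python int, so it is
    # declared static below rather than traced.
    return lax.index_in_dim(x, index, axis=0, keepdims=False)

take_row_jit = jax.jit(take_row, static_argnums=1)

x = jnp.arange(12).reshape(3, 4)
row = take_row_jit(x, 2)    # the last row, [8, 9, 10, 11]
first = take_row_jit(x, 0)  # a new static value triggers a separate compilation
```

When the index is only known as a traced value inside compiled code, lax.dynamic_index_in_dim is the dynamic counterpart.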

Examples

Here is a one-dimensional example:

>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.arange(4)
>>> lax.index_in_dim(x, 2)
Array([2], dtype=int32)
>>> lax.index_in_dim(x, 2, keepdims=False)
Array(2, dtype=int32)

Here are some two-dimensional examples:

>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> lax.index_in_dim(x, 1)
Array([[4, 5, 6, 7]], dtype=int32)
>>> lax.index_in_dim(x, 1, axis=1, keepdims=False)
Array([1, 5, 9], dtype=int32)