jax.lax.dynamic_index_in_dim

jax.lax.dynamic_index_in_dim(operand, index, axis=0, keepdims=True)

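Convenience wrapper around dynamic_slice that performs integer indexing along a single axis.

This is roughly equivalent to Python integer indexing applied along the specified axis: with keepdims=False, axis=0 corresponds to operand[index] and axis=1 to operand[:, index]. With keepdims=True (the default), the indexed axis is kept with size 1, as in the slice operand[index:index + 1] for axis=0.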
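One way to check this equivalence is to compare the result against ordinary jax.numpy indexing. The sketch below uses a static index only so that the comparison slices can be written out; the array `x` and the index `i` are illustrative values, not part of the API.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(3, 4)
i = 2  # illustrative index; in real use it may be a traced value

# keepdims=False removes the indexed axis, like x[:, i] for axis=1.
dropped = lax.dynamic_index_in_dim(x, i, axis=1, keepdims=False)
# keepdims=True keeps the indexed axis with size 1, like the slice x[:, i:i + 1].
kept = lax.dynamic_index_in_dim(x, i, axis=1, keepdims=True)

print(dropped.shape, kept.shape)  # (3,) (3, 1)
print(bool(jnp.array_equal(dropped, x[:, i])))     # True
print(bool(jnp.array_equal(kept, x[:, i:i + 1])))  # True
```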

Parameters:
  • operand (Array | np.ndarray) – the array to index.

  • index (int | Array) – the index to select along axis; may be a dynamic (traced) scalar.

  • axis (int) – the axis along which to index (default 0).

  • keepdims (bool) – whether to keep the indexed axis with size 1, so that the output has the same rank as the input (default True).
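Because index may be a traced value, the function can be used inside jax.jit, where the index is not known when the function is traced. A minimal sketch, assuming a hypothetical helper `take_row` that selects one row of a matrix:

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def take_row(m, i):
    # Hypothetical helper: `i` is a traced scalar inside jit. The output
    # shape (m.shape[1],) does not depend on the value of `i`.
    return lax.dynamic_index_in_dim(m, i, axis=0, keepdims=False)

m = jnp.arange(12).reshape(3, 4)
print(take_row(m, 0))  # [0 1 2 3]
print(take_row(m, 2))  # [ 8  9 10 11]
```

Since the result shape is fixed by the slice size rather than by the index value, both calls above can reuse the same compiled computation.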

Return type:

Array

Returns:

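An array containing the selected slice. With keepdims=True it has the same rank as operand; with keepdims=False the indexed axis is removed, so the rank is one less.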
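The effect of keepdims on the output shape is easiest to see on a higher-rank input. This sketch assumes an illustrative 3-D array of shape (2, 3, 4) and indexes position 1 along each axis in turn:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((2, 3, 4))  # illustrative input; only its shape matters here

shapes = {}
for axis in range(x.ndim):
    # keepdims=True: the indexed axis becomes size 1, rank unchanged.
    kept = lax.dynamic_index_in_dim(x, 1, axis=axis, keepdims=True)
    # keepdims=False: the indexed axis is removed, rank drops by one.
    dropped = lax.dynamic_index_in_dim(x, 1, axis=axis, keepdims=False)
    shapes[axis] = (kept.shape, dropped.shape)
    print(axis, kept.shape, dropped.shape)
# 0 (1, 3, 4) (3, 4)
# 1 (2, 1, 4) (2, 4)
# 2 (2, 3, 1) (2, 3)
```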

Examples

Here is a one-dimensional example:

>>> import jax.numpy as jnp
>>> from jax.lax import dynamic_index_in_dim
>>> x = jnp.arange(5)
>>> dynamic_index_in_dim(x, 1)
Array([1], dtype=int32)
>>> dynamic_index_in_dim(x, 1, keepdims=False)
Array(1, dtype=int32)

Here is a two-dimensional example:

>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> dynamic_index_in_dim(x, 1, axis=1, keepdims=False)
Array([1, 5, 9], dtype=int32)
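Because this function is built on dynamic_slice, an out-of-range index does not raise an error: dynamic_slice clamps its start indices so that the slice stays within bounds. The sketch below illustrates this with non-negative indices only; the array `x` and the index values are illustrative.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(5)

# In range: selects element 3.
in_range = lax.dynamic_index_in_dim(x, 3, keepdims=False)
# Past the end: dynamic_slice clamps the start index to 4, the last valid
# start for a size-1 slice, so this returns the last element.
clamped = lax.dynamic_index_in_dim(x, 10, keepdims=False)

print(in_range, clamped)  # 3 4
```

Code that must detect out-of-range indices therefore needs to check them explicitly, since the clamped result is indistinguishable from a valid lookup.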