jax.lax.dynamic_index_in_dim
- jax.lax.dynamic_index_in_dim(operand, index, axis=0, keepdims=True)
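Convenience wrapper around dynamic_slice to perform integer indexing.
This is roughly equivalent to the following Python indexing syntax, with index applied along the specified axis (when keepdims=True, the indexed axis is kept with size 1 instead of being dropped):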
operand[..., index]
- Parameters:
operand (Array | np.ndarray) – an array to slice.
index (int | Array) – the (possibly dynamic) start index of the size-1 slice taken along axis.
axis (int) – the axis along which to apply the slice (defaults to 0).
keepdims (bool) – boolean specifying whether the output should have the same rank as the input (defaults to True).
- Return type:
Array
- Returns:
An array containing the slice.
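As a rough illustration of how axis and keepdims relate to the description above, the sketch below compares dynamic_index_in_dim with a length-1 lax.dynamic_slice_in_dim along the same axis (keepdims=True) and with plain integer indexing along that axis (keepdims=False). It is a minimal sketch that assumes a concrete, in-bounds, non-negative Python int index. The helper name check_equivalence is hypothetical and is not part of JAX.

```python
import jax.numpy as jnp
from jax import lax


def check_equivalence(operand, index, axis):
    """Hypothetical helper: compare dynamic_index_in_dim with equivalent forms."""
    # keepdims=True: a size-1 dimension remains at `axis`, which matches
    # a length-1 dynamic slice starting at `index` along that axis.
    kept = lax.dynamic_index_in_dim(operand, index, axis=axis, keepdims=True)
    sliced = lax.dynamic_slice_in_dim(operand, index, 1, axis=axis)
    assert kept.shape == sliced.shape
    assert bool(jnp.array_equal(kept, sliced))

    # keepdims=False: the size-1 dimension is dropped, which matches plain
    # integer indexing along `axis` (assumes `index` is a concrete,
    # in-bounds, non-negative Python int).
    dropped = lax.dynamic_index_in_dim(operand, index, axis=axis, keepdims=False)
    indexer = [slice(None)] * operand.ndim
    indexer[axis] = index
    expected = operand[tuple(indexer)]
    assert dropped.shape == expected.shape
    assert bool(jnp.array_equal(dropped, expected))

    return kept.shape, dropped.shape


x = jnp.arange(24).reshape(2, 3, 4)
print(check_equivalence(x, 1, axis=0))  # ((1, 3, 4), (3, 4))
print(check_equivalence(x, 2, axis=2))  # ((2, 3, 1), (2, 3))
```

The printed shapes show the only difference between the two keepdims settings: whether the indexed axis survives with size 1.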
Examples
Here is a one-dimensional example:
>>> import jax.numpy as jnp
>>> from jax.lax import dynamic_index_in_dim
>>> x = jnp.arange(5)
>>> dynamic_index_in_dim(x, 1)
Array([1], dtype=int32)
>>> dynamic_index_in_dim(x, 1, keepdims=False)
Array(1, dtype=int32)
Here is a two-dimensional example:
>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> dynamic_index_in_dim(x, 1, axis=1, keepdims=False)
Array([1, 5, 9], dtype=int32)
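Because index may be a traced value, the function can be called inside jax.jit with the index passed as a runtime argument. The sketch below assumes a hypothetical helper take_column. Since the output shape depends only on the shape of x and not on the value of i, the same compiled function serves every index. It also shows a behavior inherited from the underlying dynamic slice: an out-of-range start index is clamped so the slice stays in bounds rather than raising an error.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def take_column(x, i):
    # Hypothetical helper. `i` is traced under jit; the result shape
    # depends only on x.shape, so one compiled version handles any i.
    return lax.dynamic_index_in_dim(x, i, axis=1, keepdims=False)


x = jnp.arange(12).reshape(3, 4)
cols = [take_column(x, jnp.int32(i)) for i in range(4)]
print(cols[1])  # [1 5 9]

# The underlying dynamic slice clamps the start index so the size-1 slice
# stays in bounds: index 10 on an axis of size 4 selects column 3.
print(take_column(x, jnp.int32(10)))  # [ 3  7 11]
```

Clamping means an out-of-range index silently returns an edge element. Validate indices separately if that matters for your use case.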