jax.lax.dynamic_slice_in_dim

jax.lax.dynamic_slice_in_dim(operand, start_index, slice_size, axis=0)

Convenience wrapper around lax.dynamic_slice() applied to one dimension.

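This is roughly equivalent to applying the Python slice start_index:start_index + slice_size along dimension axis. For example, with the default axis=0 it corresponds to operand[start_index:start_index + slice_size], and along the last axis to operand[..., start_index:start_index + slice_size]. Unlike plain slicing, start_index may be a traced value, and an out-of-bounds start is clamped rather than shortening the result.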
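As a sketch of the wrapper relationship described above, the snippet below computes the same row slice three ways: with dynamic_slice_in_dim, with the underlying lax.dynamic_slice (start 0 and the full extent on every other axis), and with ordinary indexing, which is interchangeable here only because the start is a static, in-bounds integer.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(3, 4)

# Two rows starting at row 1, via the one-axis wrapper (axis defaults to 0).
a = lax.dynamic_slice_in_dim(x, 1, 2)

# The same slice via lax.dynamic_slice: start 1 / size 2 on axis 0,
# start 0 / full size 4 on axis 1.
b = lax.dynamic_slice(x, (1, 0), (2, 4))

# Ordinary indexing matches for a static, in-bounds start.
c = x[1:3]
```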

Parameters:
  • operand (Array | np.ndarray) – an array to slice.

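  • start_index (ArrayLike) – the (possibly dynamic) scalar start index along axis.

  • slice_size (int) – the static slice size. It must be a concrete Python int rather than a traced value, because it determines the output shape.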

  • axis (int) – the axis along which to apply the slice (defaults to 0)
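Because start_index may be dynamic while slice_size must be static, the function fits naturally inside jax.jit. In the sketch below, the helper name window is hypothetical: the start index is traced, and the size is marked as a static argument so that each distinct size is compiled separately.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical helper for illustration. `start` is traced under jit;
# `size` is static because it determines the output shape.
@partial(jax.jit, static_argnames="size")
def window(x, start, size):
    # A plain slice x[start:start + size] would fail here, since JAX
    # requires static slice bounds for NumPy-style indexing.
    return lax.dynamic_slice_in_dim(x, start, size)


x = jnp.arange(10)
a = window(x, 4, 3)  # start=4 is traced, size=3 is static
b = window(x, 6, 3)  # same shapes and static size: no retrace needed
c = window(x, 0, 5)  # a new static size triggers a new compilation
```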

Return type:

Array

Returns:

An array containing the slice.
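The returned array keeps the rank and dtype of operand; only the extent along axis changes, becoming slice_size. A quick check on a three-dimensional array:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 5, 7), dtype=jnp.float32)

# Slice 3 elements along axis 1; the other axes keep their full extent.
y = lax.dynamic_slice_in_dim(x, 1, 3, axis=1)

# Expected shape: operand.shape with the `axis` entry replaced by slice_size.
expected_shape = x.shape[:1] + (3,) + x.shape[2:]
```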

Examples

Here is a one-dimensional example:

>>> import jax.numpy as jnp
>>> from jax.lax import dynamic_slice_in_dim
>>> x = jnp.arange(5)
>>> dynamic_slice_in_dim(x, 1, 3)
Array([1, 2, 3], dtype=int32)

As with jax.lax.dynamic_slice, an out-of-bounds start index is clamped so that the slice fits within the array; the result always has slice_size elements along axis:

>>> dynamic_slice_in_dim(x, 4, 3)
Array([2, 3, 4], dtype=int32)
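The clamping rule can be sketched as a small NumPy reference: a non-negative start is clamped into the range [0, n - slice_size], where n is the length of operand along axis. The helper clamped_slice_reference is hypothetical, and the sketch deliberately covers only non-negative starts, since negative start indices may be handled differently by lax.dynamic_slice.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax


def clamped_slice_reference(x, start, size, axis=0):
    # Hypothetical reference: clamp a non-negative start so that the
    # window [start, start + size) always fits inside the axis.
    n = x.shape[axis]
    start = min(max(start, 0), n - size)
    return np.take(x, np.arange(start, start + size), axis=axis)


x = np.arange(5)
starts = [0, 1, 2, 3, 4, 7]  # non-negative starts only

# Results from JAX, converted back to NumPy for comparison.
results = {
    s: np.asarray(lax.dynamic_slice_in_dim(jnp.asarray(x), s, 3))
    for s in starts
}
```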

Here is a two-dimensional example:

>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> dynamic_slice_in_dim(x, 1, 2, axis=1)
Array([[ 1,  2],
       [ 5,  6],
       [ 9, 10]], dtype=int32)
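Because start_index can be a traced value, the function also composes with jax.vmap. As a sketch (the helper name column_window is illustrative), mapping over a vector of start indices extracts every width-2 column window of the two-dimensional array above in a single call:

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(3, 4)


def column_window(start):
    # Width-2 window along axis 1; `start` is batched by vmap.
    return lax.dynamic_slice_in_dim(x, start, 2, axis=1)


# Starts 0, 1, 2 cover every in-bounds width-2 window of the 4 columns.
windows = jax.vmap(column_window)(jnp.arange(3))
```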