jax.lax.dynamic_slice

jax.lax.dynamic_slice(operand, start_indices, slice_sizes)

Wraps XLA’s DynamicSlice operator.

Parameters:
  • operand (Array | np.ndarray) – an array to slice.

  • start_indices (Array | np.ndarray | Sequence[ArrayLike]) – a list of scalar indices, one per dimension. These values may be dynamic.

  • slice_sizes (Shape) – the size of the slice. Must be a sequence of non-negative integers with length equal to ndim(operand). Inside a JIT-compiled function, only static values are supported (all JAX arrays inside JIT must have a statically known size).

Return type:

Array

Returns:

An array containing the slice.

Examples

Here is a simple two-dimensional dynamic slice:

>>> import jax.numpy as jnp
>>> from jax.lax import dynamic_slice
>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> dynamic_slice(x, (1, 1), (2, 3))
Array([[ 5,  6,  7],
       [ 9, 10, 11]], dtype=int32)
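
The parameter descriptions distinguish dynamic start indices from static slice sizes. The following is a minimal sketch of that distinction under jax.jit; the function name take_window and its arguments row and col are hypothetical names chosen for illustration, not part of the JAX API.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def take_window(x, row, col):
    # row and col are traced (dynamic) values inside the jitted function.
    # The slice sizes (2, 3) are a plain Python tuple, so the output shape
    # is known when the function is compiled.
    return lax.dynamic_slice(x, (row, col), (2, 3))


x = jnp.arange(12).reshape(3, 4)
window = take_window(x, 1, 1)  # rows 1-2, columns 1-3, as in the example above
other = take_window(x, 0, 0)   # a different start with the same slice sizes
```

Because only the start indices vary between the two calls, the output shape stays (2, 3) in both cases.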

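Note the potentially surprising behavior when the requested slice overruns the bounds of the array: the start index is adjusted so that a slice of the requested size fits within the array. Below, a slice of width 4 cannot start at column 1 of a 4-column array, so the column start index is adjusted to 0: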

>>> dynamic_slice(x, (1, 1), (2, 4))
Array([[ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
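
The adjustment can be reproduced by hand. The sketch below assumes that, for non-negative start indices, each start is clamped to at most dim - size along its dimension; clamped_starts is a hypothetical helper written for illustration and is not a JAX function.

```python
import jax.numpy as jnp
from jax import lax


def clamped_starts(shape, start_indices, slice_sizes):
    # Hypothetical helper, valid for non-negative start indices only:
    # limit each start so that start + size does not exceed the dimension.
    return tuple(
        min(int(start), dim - size)
        for dim, start, size in zip(shape, start_indices, slice_sizes)
    )


x = jnp.arange(12).reshape(3, 4)
starts = clamped_starts(x.shape, (1, 1), (2, 4))
rows, cols = starts

# Ordinary static slicing at the adjusted start indices.
expected = x[rows:rows + 2, cols:cols + 4]

# dynamic_slice called with the original, overrunning start indices.
result = lax.dynamic_slice(x, (1, 1), (2, 4))
```

Under that assumption, starts is (1, 0), and result matches expected: the row start is unchanged while the column start moves from 1 to 0.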