jax.lax.dynamic_slice
- jax.lax.dynamic_slice(operand, start_indices, slice_sizes)
Wraps XLA’s DynamicSlice operator.
- Parameters:
operand (Array | np.ndarray) – an array to slice.
start_indices (Array | np.ndarray | Sequence[ArrayLike]) – a list of scalar indices, one per dimension. These values may be dynamic.
slice_sizes (Shape) – the size of the slice. Must be a sequence of non-negative integers with length equal to ndim(operand). Inside a JIT-compiled function, only static values are supported, because all JAX arrays inside JIT must have a statically known size.
- Returns:
An array containing the slice.
- Return type:
Array
Examples
Here is a simple two-dimensional dynamic slice:
>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> dynamic_slice(x, (1, 1), (2, 3))
Array([[ 5,  6,  7],
       [ 9, 10, 11]], dtype=int32)
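The split between dynamic start indices and static slice sizes matters most under jax.jit. The following is a minimal sketch, not part of the original reference: the function name window and its arguments i and j are hypothetical, chosen only to show traced start indices combined with a slice shape fixed in Python.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(3, 4)

@jax.jit
def window(x, i, j):
    # i and j are traced (dynamic) values inside jit.
    # The slice shape (2, 3) is a Python constant, so the output
    # shape is known when the function is traced.
    return lax.dynamic_slice(x, (i, j), (2, 3))

result = window(x, 1, 1)
print(result)
```

Calling window with different values of i and j should reuse the same compiled function, since only the slice shape affects the output shape.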
Note the potentially surprising behavior when the requested slice overruns the bounds of the array: the start indices are adjusted (clamped) so that a slice of the requested size fits within the array:
>>> dynamic_slice(x, (1, 1), (2, 4))
Array([[ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
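One way to reason about the adjustment is to compute the adjusted start indices by hand. The sketch below assumes non-negative start indices and slice sizes no larger than the array; adjusted_starts is a hypothetical helper written for illustration, not a JAX function.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(3, 4)

def adjusted_starts(shape, start_indices, slice_sizes):
    # Hypothetical helper (assumes non-negative starts): lower each
    # start index until the slice of the requested size fits, i.e.
    # start <= dim - size along every axis.
    return tuple(min(s, d - n)
                 for s, d, n in zip(start_indices, shape, slice_sizes))

starts = adjusted_starts(x.shape, (1, 1), (2, 4))

# Slicing with the original and the adjusted starts gives the same result.
a = lax.dynamic_slice(x, (1, 1), (2, 4))
b = lax.dynamic_slice(x, starts, (2, 4))
print(starts, bool((a == b).all()))
```

Here the column start drops from 1 to 0 because a width-4 slice can only begin at column 0 of a 4-column array, while the row start is unchanged.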