jax.lax.dynamic_slice
- jax.lax.dynamic_slice(operand, start_indices, slice_sizes)
Wraps XLA's DynamicSlice operator.
- Parameters
  - operand (Array) – an array to slice.
  - start_indices (Union[Array, Sequence[Union[Array, ndarray, bool_, number, bool, int, float, complex]]]) – a sequence of scalar indices, one per dimension of operand. These values may be dynamic (traced), for example arguments of a JIT-compiled function.
  - slice_sizes (Sequence[Union[int, Any]]) – the size of the slice. Must be a sequence of non-negative integers with length equal to ndim(operand). Inside a JIT-compiled function, only static values are supported (all JAX arrays inside JIT must have statically known size).
- Return type
Array
- Returns
An array containing the slice.
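A minimal sketch of the static/dynamic split described in the parameters above: the start indices are traced values (an argument of a jax.jit-compiled function, and a batch of starts under jax.vmap), while slice_sizes stays a tuple of Python ints so the output shape is known at trace time. The function name row_window is hypothetical and used only for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(3, 4)

@jax.jit
def row_window(x, row):
    # `row` is traced under jit; the slice size (2, 4) is a static tuple of ints.
    return lax.dynamic_slice(x, (row, 0), (2, 4))

block = row_window(x, 1)  # rows 1 and 2, all four columns

# Batched start indices under vmap: one (2, 2) window per start value s,
# taken from position (s, s).
starts = jnp.array([0, 1])
windows = jax.vmap(lambda s: lax.dynamic_slice(x, (s, s), (2, 2)))(starts)
```

Passing a traced value in slice_sizes instead would fail under jit, because the shape of the result could not be determined at trace time.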
Examples
Here is a simple two-dimensional dynamic slice:
>>> import jax.numpy as jnp
>>> from jax.lax import dynamic_slice
>>> x = jnp.arange(12).reshape(3, 4)
>>> x
Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
>>> dynamic_slice(x, (1, 1), (2, 3))
Array([[ 5,  6,  7],
       [ 9, 10, 11]], dtype=int32)
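When the requested slice lies entirely within the array, as in the example above, the result should match ordinary static slicing. The sketch below checks this for the same start indices and sizes, assuming an in-bounds slice; it also shows that the output shape equals slice_sizes and the dtype is unchanged.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(3, 4)

start, size = (1, 1), (2, 3)
dyn = lax.dynamic_slice(x, start, size)

# Equivalent static slice, valid only because start + size stays in bounds.
static = x[start[0]:start[0] + size[0], start[1]:start[1] + size[1]]

assert (dyn == static).all()
assert dyn.shape == size      # output shape is exactly slice_sizes
assert dyn.dtype == x.dtype   # dtype is preserved
```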
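Note the potentially surprising behavior when the requested slice overruns the bounds of the array: instead of returning a smaller slice or raising an error, the start index is clamped so that a slice of the requested size is returned. In the following example, the column start index 1 is moved back to 0, because a slice of width 4 in an array with 4 columns can only start at column 0: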
>>> dynamic_slice(x, (1, 1), (2, 4))
Array([[ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)
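To make the clamping rule explicit, here is a minimal NumPy model of it, assuming the rule is the one described for XLA's DynamicSlice: each start index is clamped into the range [0, dim - size] for its dimension. The helper dynamic_slice_reference is hypothetical and not part of JAX. It only models non-negative start indices, since the handling of negative indices may differ between JAX versions. The loop compares it against lax.dynamic_slice for a few in-range and overrunning starts.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

def dynamic_slice_reference(operand, start_indices, slice_sizes):
    """Hypothetical NumPy model of the clamping (not JAX's implementation)."""
    operand = np.asarray(operand)
    index = []
    for start, size, dim in zip(start_indices, slice_sizes, operand.shape):
        # Clamp the start into [0, dim - size] so the full slice fits.
        clamped = min(max(int(start), 0), dim - size)
        index.append(slice(clamped, clamped + size))
    return operand[tuple(index)]

x = jnp.arange(12).reshape(3, 4)
for size in [(2, 3), (2, 4)]:
    for start in [(0, 0), (1, 1), (2, 2), (1, 3)]:
        expected = dynamic_slice_reference(x, start, size)
        actual = lax.dynamic_slice(x, start, size)
        assert np.array_equal(np.asarray(actual), expected)
```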