jax.lax.slice_in_dim
- jax.lax.slice_in_dim(operand, start_index, limit_index, stride=1, axis=0)
Convenience wrapper around lax.slice() that applies the slice to only one dimension.
This is effectively equivalent to operand[..., start_index:limit_index:stride], with the indexing applied on the specified axis.
- Parameters:
operand (Array | np.ndarray) – an array to slice.
start_index (int | None) – an optional start index (defaults to zero)
limit_index (int | None) – an optional end index (defaults to operand.shape[axis])
stride (int) – an optional stride (defaults to 1)
axis (int) – the axis along which to apply the slice (defaults to 0)
- Returns:
An array containing the slice.
- Return type:
Array
Examples
Here is a one-dimensional example:
>>> x = jnp.arange(4)
>>> lax.slice_in_dim(x, 1, 3)
Array([1, 2], dtype=int32)
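The optional parameters can be illustrated the same way. The following is a minimal sketch, assuming the documented defaults (start_index=None meaning zero, limit_index=None meaning operand.shape[axis]); the variable names head, tail, and strided are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10)

# start_index=None falls back to 0; limit_index=None falls back to x.shape[axis].
head = lax.slice_in_dim(x, None, 4)   # same elements as x[:4]
tail = lax.slice_in_dim(x, 6, None)   # same elements as x[6:]

# A stride of 3 keeps every third element in the half-open range [1, 9).
strided = lax.slice_in_dim(x, 1, 9, stride=3)   # same elements as x[1:9:3]

print(head, tail, strided)
```

Passing None keeps the call site explicit about which bound is left open, which mirrors how x[:4] and x[6:] read in ordinary indexing.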
Here are some two-dimensional examples:
>>> x = jnp.arange(12).reshape(4, 3)
>>> x
Array([[ 0,  1,  2],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11]], dtype=int32)
>>> lax.slice_in_dim(x, 1, 3)
Array([[3, 4, 5],
       [6, 7, 8]], dtype=int32)
>>> lax.slice_in_dim(x, 1, 3, axis=1)
Array([[ 1,  2],
       [ 4,  5],
       [ 7,  8],
       [10, 11]], dtype=int32)
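To make the relationship with lax.slice() and ordinary indexing concrete, here is a short sketch comparing the three forms on the same two-dimensional array. The variable names are illustrative; the comparison assumes static, in-bounds indices.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(4, 3)

# Select rows 1 and 2 (axis 0) in three equivalent ways.
via_slice_in_dim = lax.slice_in_dim(x, 1, 3, axis=0)
via_indexing = x[1:3, :]
# lax.slice requires start and limit indices for every dimension.
via_slice = lax.slice(x, (1, 0), (3, 3))

# The same idea along axis 1, with a stride of 2.
cols = lax.slice_in_dim(x, 0, 3, stride=2, axis=1)   # same elements as x[:, 0:3:2]

print(bool((via_slice_in_dim == via_indexing).all()))
print(bool((via_slice_in_dim == via_slice).all()))
print(cols.shape)
```

With lax.slice() the untouched dimensions must be spelled out in full, whereas slice_in_dim fills them in from the operand's shape.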