jax.lax.slice_in_dim
- jax.lax.slice_in_dim(operand, start_index, limit_index, stride=1, axis=0)
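Convenience wrapper around lax.slice() that slices along a single dimension. This is effectively equivalent to operand[..., start_index:limit_index:stride], except that the indexing is applied on the specified axis rather than always on the last one.
- Parameters:
  - operand: the array to slice.
  - start_index: the index along axis at which the slice begins.
  - limit_index: the index along axis at which the slice stops; this index itself is excluded.
  - stride: the step between selected elements along axis (default 1).
  - axis: the axis along which to slice (default 0).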
- Returns:
An array containing the slice.
- Return type:
Array
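The ... in operand[..., start_index:limit_index:stride] literally targets the last axis, so for any other axis the equivalent index has full slices on the leading axes instead. The sketch below builds that index tuple and compares it with lax.slice_in_dim on a 3-D array. The helper index_tuple_for_axis is hypothetical, written only for this comparison, and assumes a non-negative axis.

```python
import jax.numpy as jnp
from jax import lax


def index_tuple_for_axis(axis, start_index, limit_index, stride=1):
    # Hypothetical helper: full slices on every axis before `axis`,
    # then the requested slice on `axis`; trailing axes are left whole.
    return (slice(None),) * axis + (slice(start_index, limit_index, stride),)


x = jnp.arange(2 * 6 * 3).reshape(2, 6, 3)

# Slice axis 1 from index 1 up to (not including) 6, taking every 2nd element,
# i.e. positions 1, 3 and 5 along that axis.
via_lax = lax.slice_in_dim(x, 1, 6, stride=2, axis=1)
via_indexing = x[index_tuple_for_axis(1, 1, 6, stride=2)]

print(via_lax.shape)  # (2, 3, 3)
print(bool(jnp.array_equal(via_lax, via_indexing)))  # True
```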
Examples
Here is a one-dimensional example:
>>> x = jnp.arange(4)
>>> lax.slice_in_dim(x, 1, 3)
Array([1, 2], dtype=int32)
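The examples on this page leave stride at its default of 1. A short sketch of a strided one-dimensional slice, checked against ordinary Python slice syntax on the same array:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10)

# Start at index 1, stop before index 8, step by 3: selects indices 1, 4 and 7.
strided = lax.slice_in_dim(x, 1, 8, stride=3)
print(strided.tolist())  # [1, 4, 7]

# The same selection written with standard start:stop:step slicing.
print(bool(jnp.array_equal(strided, x[1:8:3])))  # True
```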
Here are some two-dimensional examples:
>>> x = jnp.arange(12).reshape(4, 3)
>>> x
Array([[ 0,  1,  2],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11]], dtype=int32)
>>> lax.slice_in_dim(x, 1, 3)
Array([[3, 4, 5],
       [6, 7, 8]], dtype=int32)
>>> lax.slice_in_dim(x, 1, 3, axis=1)
Array([[ 1,  2],
       [ 4,  5],
       [ 7,  8],
       [10, 11]], dtype=int32)
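Because slice_in_dim is a convenience wrapper around lax.slice(), the axis=1 call above should match a full lax.slice call that keeps every other dimension whole. A minimal sketch of that correspondence; the expanded start, limit and stride tuples are written out by hand for this 4x3 example and are not copied from the library's source.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(4, 3)

# One-dimensional form: take columns 1 and 2 (limit 3 is exclusive) along axis 1.
short_form = lax.slice_in_dim(x, 1, 3, axis=1)

# Full form: every axis gets explicit start/limit/stride values.
# Axis 0 keeps all 4 rows; axis 1 runs from 1 up to (not including) 3.
full_form = lax.slice(x, start_indices=(0, 1), limit_indices=(4, 3), strides=(1, 1))

print(full_form.tolist())  # [[1, 2], [4, 5], [7, 8], [10, 11]]
print(bool(jnp.array_equal(short_form, full_form)))  # True
```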