jax.lax.slice_in_dim


jax.lax.slice_in_dim(operand, start_index, limit_index, stride=1, axis=0)

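Convenience wrapper around lax.slice() that slices along a single dimension.

This is effectively equivalent to NumPy-style indexing with start_index:limit_index:stride applied along the specified axis, with every other axis left unsliced. For the last axis, that is operand[..., start_index:limit_index:stride].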
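The equivalence described above can be checked with a short sketch. The array `x` and the variable names are illustrative only; the comparison uses `lax.slice` with explicit per-dimension start, limit, and stride tuples, which is how the wrapper is described, not necessarily how it is implemented internally.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(4, 3)

# Keep columns 0 and 2: start=0, limit=3, stride=2 along axis 1.
via_slice_in_dim = lax.slice_in_dim(x, 0, 3, stride=2, axis=1)

# The same selection written as NumPy-style indexing on axis 1.
via_indexing = x[:, 0:3:2]

# The same selection with lax.slice, spelling out every dimension:
# axis 0 is taken whole (0 to 4, stride 1), axis 1 is strided.
via_slice = lax.slice(x, (0, 0), (4, 3), (1, 2))

print(via_slice_in_dim)
```

All three expressions should produce the same (4, 2) array; slice_in_dim simply avoids writing out the start, limit, and stride for the axes that are not being sliced.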

Parameters:
  • operand (Array | np.ndarray) – an array to slice.

  • start_index (int | None) – the start index of the slice along axis; None is treated as zero.

  • limit_index (int | None) – the end index of the slice along axis (exclusive); None is treated as operand.shape[axis].

  • stride (int) – an optional stride (defaults to 1)

  • axis (int) – the axis along which to apply the slice (defaults to 0)

Return type:

Array

Returns:

An array containing the slice.

Examples

Here is a one-dimensional example:

>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.arange(4)
>>> lax.slice_in_dim(x, 1, 3)
Array([1, 2], dtype=int32)

Here are some two-dimensional examples:

>>> x = jnp.arange(12).reshape(4, 3)
>>> x
Array([[ 0,  1,  2],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11]], dtype=int32)
>>> lax.slice_in_dim(x, 1, 3)
Array([[3, 4, 5],
       [6, 7, 8]], dtype=int32)
>>> lax.slice_in_dim(x, 1, 3, axis=1)
Array([[ 1,  2],
       [ 4,  5],
       [ 7,  8],
       [10, 11]], dtype=int32)
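
The optional arguments listed under Parameters can be exercised in the same way. The following is a minimal sketch on the same 4x3 array; the variable names are illustrative, and it covers only passing None for start_index or limit_index and a non-unit stride.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(4, 3)

# start_index=None means "from the beginning of the axis": rows 0 and 1.
first_two_rows = lax.slice_in_dim(x, None, 2)

# limit_index=None means "to the end of the axis": columns 1 and 2.
last_two_cols = lax.slice_in_dim(x, 1, None, axis=1)

# stride=2 keeps every second row in the range [0, 4): rows 0 and 2.
every_other_row = lax.slice_in_dim(x, 0, 4, stride=2)

print(first_two_rows)
print(last_two_cols)
print(every_other_row)
```

Note that start_index and limit_index have no default in the signature, so they must be passed explicitly, either as an integer or as None.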