# jax.lax.slice_in_dim

jax.lax.slice_in_dim(operand, start_index, limit_index, stride=1, axis=0)

Convenience wrapper around `lax.slice()` that slices along a single dimension.
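As a rough sketch of what "convenience wrapper" means here, the function below builds full-range start, limit, and stride lists for every dimension, overrides only the entries for `axis`, and then calls `lax.slice()`. The name `slice_one_axis` is hypothetical, and this is an illustration of the idea rather than JAX's actual implementation: it handles only the documented `None` defaults and performs no extra validation.

```python
import jax.numpy as jnp
from jax import lax


def slice_one_axis(operand, start_index, limit_index, stride=1, axis=0):
    """Hypothetical one-axis slice expressed directly with lax.slice."""
    # Full-range defaults for every dimension: start at 0, stop at the
    # dimension size, step by 1.
    start_indices = [0] * operand.ndim
    limit_indices = list(operand.shape)
    strides = [1] * operand.ndim
    # Resolve the documented defaults for the sliced axis.
    if start_index is None:
        start_index = 0
    if limit_index is None:
        limit_index = operand.shape[axis]
    # Override only the entries belonging to the chosen axis.
    start_indices[axis] = start_index
    limit_indices[axis] = limit_index
    strides[axis] = stride
    return lax.slice(operand, start_indices, limit_indices, strides)


x = jnp.arange(12).reshape(4, 3)
print(slice_one_axis(x, 1, 3, axis=1).tolist())  # [[1, 2], [4, 5], [7, 8], [10, 11]]
```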

This is effectively equivalent to `operand[..., start_index:limit_index:stride]`, except that the slice is applied along the dimension given by `axis` (axis 0 by default) rather than the last one.
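Because `...` in that expression places the slice on the last axis, the sketch below spells out the general form: an index tuple with `slice(None)` for every axis before `axis`, followed by `slice(start_index, limit_index, stride)`. The array shape and parameter values are arbitrary choices for illustration.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(24).reshape(2, 3, 4)
start, limit, stride, axis = 0, 3, 2, 1

# Take everything on the axes before `axis`, then apply start:limit:stride
# on `axis`; any trailing axes are left whole implicitly.
index = (slice(None),) * axis + (slice(start, limit, stride),)
via_indexing = x[index]  # same as x[:, 0:3:2] here
via_lax = lax.slice_in_dim(x, start, limit, stride=stride, axis=axis)

print(via_indexing.shape)  # (2, 2, 4)
print(bool(jnp.array_equal(via_indexing, via_lax)))  # True
```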

Parameters:
• operand (Array | ndarray) – an array to slice.

• start_index (int | None) – an optional start index (defaults to zero)

• limit_index (int | None) – an optional end index (defaults to operand.shape[axis])

• stride (int) – an optional stride (defaults to 1)

• axis (int) – the axis along which to apply the slice (defaults to 0)
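A short sketch of the defaults and of `stride`, using arbitrary input values: passing `None` for both bounds keeps the whole axis, and a stride of 3 keeps every third element of the half-open range `[1, 8)`, matching Python's `x[1:8:3]`.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(10)

# start_index=None defaults to 0; limit_index=None defaults to x.shape[0].
print(lax.slice_in_dim(x, None, None).tolist())  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# stride=3 keeps indices 1, 4, 7 of the range [1, 8).
print(lax.slice_in_dim(x, 1, 8, stride=3).tolist())  # [1, 4, 7]
```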

Returns:

An array containing the slice.

Return type:

Array

Examples

Here is a one-dimensional example:

```
>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.arange(4)
>>> lax.slice_in_dim(x, 1, 3)
Array([1, 2], dtype=int32)
```

Here are some two-dimensional examples:

```
>>> x = jnp.arange(12).reshape(4, 3)
>>> x
Array([[ 0,  1,  2],
       [ 3,  4,  5],
       [ 6,  7,  8],
       [ 9, 10, 11]], dtype=int32)
```
```
>>> lax.slice_in_dim(x, 1, 3)
Array([[3, 4, 5],
       [6, 7, 8]], dtype=int32)
```
```
>>> lax.slice_in_dim(x, 1, 3, axis=1)
Array([[ 1,  2],
       [ 4,  5],
       [ 7,  8],
       [10, 11]], dtype=int32)
```
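The index arguments are ordinary Python integers, so under `jax.jit` they need to be known when the function is traced. One way to satisfy this is to derive them from the array's shape, which is static under `jit`. The helper `forward_diff` below is hypothetical; it computes first differences along an axis, which for this input matches `jnp.diff`.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@partial(jax.jit, static_argnames="axis")
def forward_diff(x, axis=0):
    """Hypothetical helper: first differences along `axis`."""
    n = x.shape[axis]  # shapes are static under jit, so n is a Python int
    upper = lax.slice_in_dim(x, 1, n, axis=axis)      # elements 1 .. n-1
    lower = lax.slice_in_dim(x, 0, n - 1, axis=axis)  # elements 0 .. n-2
    return upper - lower


x = jnp.arange(12).reshape(4, 3)
print(forward_diff(x, axis=1).tolist())  # [[1, 1], [1, 1], [1, 1], [1, 1]]
```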