jax.numpy.roll

jax.numpy.roll(a, shift, axis=None)

Roll the elements of an array along a specified axis.

JAX implementation of numpy.roll().
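Because jnp.roll mirrors numpy.roll, the two should agree on the same input. The following is a minimal consistency check, assuming both NumPy and JAX are installed. The test array x and the list of shift/axis cases are illustrative choices, not part of the API.

```python
import numpy as np
import jax.numpy as jnp

x = np.arange(12).reshape(3, 4)

# Illustrative (shift, axis) combinations: flattened, single axis,
# negative shift, and one shift per axis.
cases = [(2, None), (1, 0), (-1, 1), ((2, 3), (0, 1))]
for shift, axis in cases:
    expected = np.roll(x, shift, axis=axis)
    actual = jnp.roll(x, shift, axis=axis)
    # Compare values only; JAX may use int32 where NumPy uses int64.
    assert np.array_equal(np.asarray(actual), expected), (shift, axis)
```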

Parameters:
  • a (ArrayLike) – input array.

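  • shift (ArrayLike | Sequence[int]) – the number of positions to shift elements along the specified axis or axes. Positive values move elements toward higher indices and negative values toward lower indices; elements shifted past either end wrap around to the other. If an integer, every axis listed in axis is shifted by the same amount. If a sequence, it gives the shift for each corresponding axis in axis individually.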

  • axis (int | Sequence[int] | None) – the axis or axes to roll. If None (the default), the array is flattened, shifted, and then reshaped to its original shape.
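The shift and axis semantics described above can be checked directly. The sketch below assumes only jax.numpy; the 3×4 test array is an arbitrary choice. It shows that an integer shift is applied to every listed axis, that a negative shift rolls toward lower indices (and so undoes a positive one), and that axis=None behaves like rolling the flattened array and reshaping.

```python
import jax.numpy as jnp

a = jnp.arange(12).reshape(3, 4)

# An integer shift with several axes applies the same shift to each listed axis.
same = jnp.roll(a, 1, axis=(0, 1))
per_axis = jnp.roll(a, (1, 1), axis=(0, 1))
assert (same == per_axis).all()

# A negative shift moves elements toward lower indices, wrapping at the start.
assert jnp.roll(jnp.arange(5), -2).tolist() == [2, 3, 4, 0, 1]
# Rolling by +3 and then by -3 along the same axis restores the input.
assert (jnp.roll(jnp.roll(a, 3, axis=1), -3, axis=1) == a).all()

# axis=None: flatten, roll, then restore the original shape.
flat = jnp.roll(a.ravel(), 5).reshape(a.shape)
assert (jnp.roll(a, 5) == flat).all()
```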

Returns:

A copy of a with elements rolled along the specified axis or axes.

Return type:

Array
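Since JAX arrays are immutable, "a copy" here means a new array: the input is left untouched, and the result keeps the input's shape and dtype. A short check, using an arbitrary 2×2 float32 input chosen for illustration:

```python
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=jnp.float32)
rolled = jnp.roll(a, 1, axis=1)

# The result has the same shape and dtype as the input...
assert rolled.shape == a.shape and rolled.dtype == a.dtype
# ...the columns are rotated by one position...
assert rolled.tolist() == [[2.0, 1.0], [4.0, 3.0]]
# ...and the input itself is unchanged.
assert a.tolist() == [[1.0, 2.0], [3.0, 4.0]]
```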

See also

Examples

>>> import jax.numpy as jnp
>>> a = jnp.array([0, 1, 2, 3, 4, 5])
>>> jnp.roll(a, 2)
Array([4, 5, 0, 1, 2, 3], dtype=int32)
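In the 1-D example above, output position i takes the element from input position (i - shift) mod n, where n is the length of the array. The sketch below expresses that index mapping with modular indexing. roll_1d is a hypothetical helper written for illustration; it is not how JAX implements roll internally, only a way to reproduce the same result.

```python
import jax.numpy as jnp

def roll_1d(x, shift):
    """Hypothetical helper: roll a 1-D array via modular indexing.

    Output position i takes the element at input position (i - shift) mod n.
    """
    n = x.shape[0]
    idx = (jnp.arange(n) - shift) % n  # jnp's % wraps negatives into [0, n)
    return x[idx]

a = jnp.array([0, 1, 2, 3, 4, 5])
assert roll_1d(a, 2).tolist() == [4, 5, 0, 1, 2, 3]
assert (roll_1d(a, 2) == jnp.roll(a, 2)).all()

# Shifts larger than the length wrap around: 8 mod 6 == 2.
assert roll_1d(a, 8).tolist() == jnp.roll(a, 8).tolist() == [4, 5, 0, 1, 2, 3]
```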

Roll elements along a specific axis, or along several axes with a separate shift for each:

>>> a = jnp.array([[ 0,  1,  2,  3],
...                [ 4,  5,  6,  7],
...                [ 8,  9, 10, 11]])
>>> jnp.roll(a, 1, axis=0)
Array([[ 8,  9, 10, 11],
       [ 0,  1,  2,  3],
       [ 4,  5,  6,  7]], dtype=int32)
>>> jnp.roll(a, [2, 3], axis=[0, 1])
Array([[ 5,  6,  7,  4],
       [ 9, 10, 11,  8],
       [ 1,  2,  3,  0]], dtype=int32)
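Each roll only permutes indices along its own axis, so a multi-axis roll like the last example should equal rolling one axis at a time, in either order. A sketch of that equivalence, assuming the same 3×4 input (built here with jnp.arange):

```python
import jax.numpy as jnp

a = jnp.arange(12).reshape(3, 4)

# One call: shift axis 0 by 2 and axis 1 by 3.
combined = jnp.roll(a, [2, 3], axis=[0, 1])

# The same result, one axis at a time, in both orders.
rows_then_cols = jnp.roll(jnp.roll(a, 2, axis=0), 3, axis=1)
cols_then_rows = jnp.roll(jnp.roll(a, 3, axis=1), 2, axis=0)
assert (combined == rows_then_cols).all()
assert (combined == cols_then_rows).all()
```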