jax.numpy.roll
jax.numpy.roll(a, shift, axis=None)
Roll the elements of an array along a specified axis.
JAX implementation of numpy.roll().

- Parameters:
a (ArrayLike) – input array.
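shift (ArrayLike | Sequence[int]) – the number of positions to shift along the specified axis or axes. Positive values move elements toward higher indices, and elements shifted past the end wrap around to the start. If an integer, every specified axis is shifted by the same amount. If a sequence, it gives the shift for each corresponding axis individually.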
axis (int | Sequence[int] | None) – the axis or axes to roll. If None, the array is flattened, shifted, and then reshaped to its original shape.
- Returns:
A copy of a with elements rolled along the specified axis or axes.
- Return type:
Array
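To make the shift semantics concrete, here is a minimal pure-Python sketch of rolling a 1-D sequence. It is not JAX's implementation; `roll_1d` is a hypothetical helper that only models the index arithmetic: the element at index `i` ends up at index `(i + shift) mod n`, so negative shifts wrap toward the end and shifts larger than the length wrap around.

```python
def roll_1d(values, shift):
    """Hypothetical reference helper: roll a 1-D sequence by `shift` positions."""
    n = len(values)
    if n == 0:
        return []
    out = [None] * n
    for i, v in enumerate(values):
        # The element at index i moves to (i + shift) mod n. Python's % always
        # returns a non-negative result here, so negative shifts wrap correctly.
        out[(i + shift) % n] = v
    return out


print(roll_1d([0, 1, 2, 3, 4, 5], 2))   # [4, 5, 0, 1, 2, 3]
print(roll_1d([0, 1, 2, 3, 4, 5], -1))  # [1, 2, 3, 4, 5, 0]
print(roll_1d([0, 1, 2, 3, 4, 5], 8))   # [4, 5, 0, 1, 2, 3]
```

The first call mirrors the `jnp.roll(a, 2)` example in this page; the third shows that a shift of 8 on a length-6 sequence behaves like a shift of 2.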
See also
jax.numpy.rollaxis(): roll the specified axis to a given position.
Examples
>>> import jax.numpy as jnp
>>> a = jnp.array([0, 1, 2, 3, 4, 5])
>>> jnp.roll(a, 2)
Array([4, 5, 0, 1, 2, 3], dtype=int32)
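The example above relies on the default axis=None. For a 1-D input flattening changes nothing, but for a multi-dimensional input the roll crosses row boundaries. The following pure-Python sketch models that flatten, shift, reshape sequence for a non-empty rectangular 2-D list; `roll_flattened` is a hypothetical helper for illustration, not part of JAX.

```python
def roll_flattened(rows, shift):
    """Hypothetical helper: model axis=None rolling of a non-empty rectangular 2-D list."""
    n_cols = len(rows[0])
    # Flatten in row-major order, as reshaping an array to 1-D would.
    flat = [v for row in rows for v in row]
    n = len(flat)
    # out[i] = in[(i - shift) mod n], i.e. every element moves `shift` places forward.
    rolled = [flat[(i - shift) % n] for i in range(n)]
    # Reshape back to the original (rows x cols) layout.
    return [rolled[r * n_cols:(r + 1) * n_cols] for r in range(len(rows))]


print(roll_flattened([[0, 1, 2], [3, 4, 5]], 1))  # [[5, 0, 1], [2, 3, 4]]
```

Under these assumptions the result should agree with `jnp.roll(jnp.array([[0, 1, 2], [3, 4, 5]]), 1)`: the last element of the last row wraps to the front of the first row.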
Roll elements along a specific axis:
>>> a = jnp.array([[ 0,  1,  2,  3],
...                [ 4,  5,  6,  7],
...                [ 8,  9, 10, 11]])
>>> jnp.roll(a, 1, axis=0)
Array([[ 8,  9, 10, 11],
       [ 0,  1,  2,  3],
       [ 4,  5,  6,  7]], dtype=int32)
>>> jnp.roll(a, [2, 3], axis=[0, 1])
Array([[ 5,  6,  7,  4],
       [ 9, 10, 11,  8],
       [ 1,  2,  3,  0]], dtype=int32)
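When several axes are given, each one is rolled independently by its own shift. A minimal pure-Python sketch of that rule for a non-empty rectangular 2-D list, assuming shifts are given for axes (0, 1) in that order, is below; `roll_2d` is a hypothetical helper, not a JAX API. Each output position reads from the input position offset by the shift on that axis, taken modulo the axis length.

```python
def roll_2d(rows, shifts):
    """Hypothetical helper: roll a rectangular 2-D list along axes (0, 1)."""
    shift0, shift1 = shifts
    n_rows, n_cols = len(rows), len(rows[0])
    # out[r][c] = in[(r - shift0) mod n_rows][(c - shift1) mod n_cols]
    return [
        [rows[(r - shift0) % n_rows][(c - shift1) % n_cols] for c in range(n_cols)]
        for r in range(n_rows)
    ]


a = [[0, 1, 2, 3],
     [4, 5, 6, 7],
     [8, 9, 10, 11]]
print(roll_2d(a, (2, 3)))  # [[5, 6, 7, 4], [9, 10, 11, 8], [1, 2, 3, 0]]
print(roll_2d(a, (1, 0)))  # [[8, 9, 10, 11], [0, 1, 2, 3], [4, 5, 6, 7]]
```

The two calls reproduce the `axis=[0, 1]` and `axis=0` results shown in the examples above; a shift of 0 on an axis leaves that axis unchanged.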