jax.numpy.roll

jax.numpy.roll(a, shift, axis=None)

Roll array elements along a given axis.

LAX-backend implementation of numpy.roll().
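Since this function is described as a LAX-backend implementation of numpy.roll(), one way to check that it matches NumPy on a particular install is to compare the two directly. This is a minimal sketch that assumes both jax and numpy are installed; the input array and the (shift, axis) pairs are arbitrary illustrative choices, not an exhaustive compatibility test.

```python
import numpy as np
import jax.numpy as jnp

# The same 3x4 input as a NumPy array and as a JAX array.
x_np = np.arange(12).reshape(3, 4)
x_jax = jnp.asarray(x_np)

# Illustrative (shift, axis) combinations: flattened, single-axis, and per-axis.
cases = [(1, None), (-2, None), (1, 0), (3, 1), ((1, 2), (0, 1))]

for shift, axis in cases:
    expected = np.roll(x_np, shift, axis=axis)
    result = jnp.roll(x_jax, shift, axis=axis)
    # Compare values only; JAX may use a different integer dtype than NumPy.
    assert np.array_equal(np.asarray(result), expected)
```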

Original docstring below.

Elements that roll beyond the last position are re-introduced at the first.
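The wrap-around behavior is easiest to see on a 1-D array. The sketch below assumes jax is installed; the values in the comments follow from the rule that element i of the result is taken from position (i - shift) mod n of the input, so a shift larger than the length behaves like shift mod n.

```python
import jax.numpy as jnp

x = jnp.arange(5)            # [0 1 2 3 4]

print(jnp.roll(x, 2))        # [3 4 0 1 2]: the last two elements wrap to the front
print(jnp.roll(x, -2))       # [2 3 4 0 1]: a negative shift rolls toward the start
print(jnp.roll(x, 7))        # [3 4 0 1 2]: 7 mod 5 == 2, so same as a shift of 2

# For 0 < k < len(x), rolling by k is the same as moving the last k elements
# to the front with concatenate.
k = 2
manual = jnp.concatenate([x[-k:], x[:-k]])
assert bool(jnp.array_equal(manual, jnp.roll(x, k)))
```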

Parameters:
  • a (array_like) – Input array.

  • shift (int or tuple of ints) – The number of places by which elements are shifted. If a tuple, then axis must be a tuple of the same size, and each of the given axes is shifted by the corresponding number. If shift is an int while axis is a tuple of ints, the same value is used for all given axes.

  • axis (int or tuple of ints, optional) – Axis or axes along which elements are shifted. By default, the array is flattened before shifting, after which the original shape is restored.
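The rules for shift and axis can be checked on a small 2-D array. This is a sketch assuming jax is installed; the 2-by-3 input is arbitrary, and the expected values in the comments follow from the flatten-then-restore default and the per-axis rules listed above. Every result keeps the input's shape.

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)
# [[0 1 2]
#  [3 4 5]]

# axis=None (default): flatten to [0 1 2 3 4 5], roll by 1, restore the shape.
flat = jnp.roll(x, 1)
# [[5 0 1]
#  [2 3 4]]

# A single axis: roll the rows, or roll within each row.
rows = jnp.roll(x, 1, axis=0)
# [[3 4 5]
#  [0 1 2]]
cols = jnp.roll(x, 1, axis=1)
# [[2 0 1]
#  [5 3 4]]

# Tuple shift with tuple axis: 1 along axis 0 and -1 along axis 1.
both = jnp.roll(x, (1, -1), axis=(0, 1))
# [[4 5 3]
#  [1 2 0]]

# Int shift with tuple axis: the same shift of 1 is applied to each listed axis.
same = jnp.roll(x, 1, axis=(0, 1))
# [[5 3 4]
#  [2 0 1]]
```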

Returns:

res – Output array, with the same shape as a.

Return type:

ndarray