jax.numpy.roll
- jax.numpy.roll(a, shift, axis=None)
Roll array elements along a given axis.
LAX-backend implementation of numpy.roll(). Original docstring below.
Elements that roll beyond the last position are re-introduced at the first.
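Here is a minimal sketch of the wraparound behaviour on a 1-D array. It assumes JAX is installed, and `x`, `right`, and `left` are illustrative names. A positive `shift` moves elements toward higher indices, and a negative `shift` moves them toward lower indices.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.arange(5)        # [0, 1, 2, 3, 4]
right = jnp.roll(x, 2)   # the last two elements wrap around to the front
left = jnp.roll(x, -2)   # a negative shift rolls toward the start instead

print(np.asarray(right).tolist())  # [3, 4, 0, 1, 2]
print(np.asarray(left).tolist())   # [2, 3, 4, 0, 1]
```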
- Parameters
a (array_like) – Input array.
shift (int or tuple of ints) – The number of places by which elements are shifted. If a tuple, then axis must be a tuple of the same size, and each of the given axes is shifted by the corresponding number. If an int while axis is a tuple of ints, then the same value is used for all given axes.
axis (int or tuple of ints, optional) – Axis or axes along which elements are shifted. By default, the array is flattened before shifting, after which the original shape is restored.
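The `shift`/`axis` combinations described above can be sketched on a small 2-D array. The array and variable names are illustrative only, and JAX is assumed to be installed. Each call keeps the input shape (2, 3).

```python
import jax.numpy as jnp
import numpy as np

x = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# axis=None (default): flatten to [0, ..., 5], roll by 1, restore shape (2, 3)
flat = jnp.roll(x, 1)

# Single int axis: each row is rolled independently along axis 1
rows = jnp.roll(x, 1, axis=1)

# Tuple shift with tuple axis: shift by 1 along axis 0 and by 2 along axis 1
both = jnp.roll(x, (1, 2), axis=(0, 1))

# Int shift with tuple axis: the same shift (1) is applied to both axes
same = jnp.roll(x, 1, axis=(0, 1))

for name, res in [("flat", flat), ("rows", rows), ("both", both), ("same", same)]:
    print(name, np.asarray(res).tolist())
# flat [[5, 0, 1], [2, 3, 4]]
# rows [[2, 0, 1], [5, 3, 4]]
# both [[4, 5, 3], [1, 2, 0]]
# same [[5, 3, 4], [2, 0, 1]]
```

Note that the default `axis=None` mixes elements across rows, because the roll happens on the flattened array. Passing an explicit `axis` keeps each row's or column's elements together.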
- Returns
res – Output array, with the same shape as a.
- Return type
ndarray
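Because the result has the same shape as `a`, it can be combined elementwise with the input. The following is a minimal sketch that assumes periodic (wraparound) boundaries. It uses `roll` to form a circular forward difference `x[i+1] - x[i]`, where the last entry wraps around to `x[0] - x[-1]`. The variable names are illustrative.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.array([1.0, 4.0, 9.0, 16.0])

# roll(x, -1)[i] == x[(i + 1) % n], so this is a periodic forward difference
diff = jnp.roll(x, -1) - x

print(diff.shape)                 # (4,)
print(np.asarray(diff).tolist())  # [3.0, 5.0, 7.0, -15.0]
```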