jax.numpy.rollaxis(a, axis, start=0)

Roll the specified axis backwards, until it lies in a given position.

LAX-backend implementation of rollaxis(). Original docstring below.

This function continues to be supported for backward compatibility, but you should prefer moveaxis. The moveaxis function was added in NumPy 1.11.
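Because moveaxis is the preferred spelling, it helps to see how a rollaxis call translates. This is a minimal sketch that assumes non-negative axis and start values. moveaxis takes the axis's final index, while rollaxis takes the position the axis is inserted before. When start is greater than axis, the equivalent moveaxis destination is therefore start - 1.

```python
import jax.numpy as jnp

# Distinct values, so a wrong permutation cannot hide behind equal shapes.
a = jnp.arange(3 * 4 * 5 * 6).reshape(3, 4, 5, 6)

# start <= axis: the destination index is the same for both functions.
r1 = jnp.rollaxis(a, 3, 1)
m1 = jnp.moveaxis(a, 3, 1)

# start > axis: rollaxis inserts *before* `start`. After the axis is removed,
# that slot is final index start - 1, so moveaxis needs destination 3, not 4.
r2 = jnp.rollaxis(a, 1, 4)
m2 = jnp.moveaxis(a, 1, 3)

print(r1.shape, bool(jnp.array_equal(r1, m1)))  # (3, 6, 4, 5) True
print(r2.shape, bool(jnp.array_equal(r2, m2)))  # (3, 5, 6, 4) True
```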

Parameters

  • a (ndarray) – Input array.

  • axis (int) – The axis to roll backwards. The positions of the other axes do not change relative to one another.

  • start (int, optional) – The axis is rolled until it lies before this position. The default, 0, results in a “complete” roll.
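
The axis and start semantics above amount to a single transpose. The sketch below builds that permutation explicitly: it removes the axis and then reinserts it before start. rollaxis_perm is a hypothetical helper written for illustration and is not part of jax.numpy. It also assumes 0 <= axis < ndim and 0 <= start <= ndim, so it does not handle negative indices.

```python
import jax.numpy as jnp


def rollaxis_perm(ndim, axis, start=0):
    """Hypothetical helper: the axis permutation that rollaxis applies.

    Assumes 0 <= axis < ndim and 0 <= start <= ndim (no negative indices).
    """
    perm = list(range(ndim))
    perm.remove(axis)          # take the rolled axis out; the others keep their order
    if start > axis:
        start -= 1             # positions after `axis` shifted left by one
    perm.insert(start, axis)   # put it back just before the original `start`
    return tuple(perm)


a = jnp.zeros((3, 4, 5, 6))
perm = rollaxis_perm(a.ndim, 3, 1)
print(perm)                           # (0, 3, 1, 2)
print(jnp.transpose(a, perm).shape)   # (3, 6, 4, 5)
```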


res – For NumPy >= 1.10.0 a view of a is always returned. For earlier NumPy versions a view of a is returned only if the order of the axes is changed, otherwise the input array is returned.

Return type

ndarray

See also

moveaxis – Move array axes to new positions.

roll – Roll the elements of an array by a number of positions along a given axis.
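
rollaxis and roll are easy to confuse. rollaxis moves a whole axis, which reorders the shape. roll shifts values along an axis and leaves the shape unchanged. The following small sketch contrasts the two on a 2-D array:

```python
import jax.numpy as jnp

a = jnp.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]

# rollaxis reorders axes: the (2, 3) array becomes (3, 2), a transpose here.
moved = jnp.rollaxis(a, 1)
print(moved.shape)        # (3, 2)

# roll shifts values one step along axis 1, wrapping around; the shape is unchanged.
shifted = jnp.roll(a, 1, axis=1)
print(shifted.shape)      # (2, 3)
print(shifted.tolist())   # [[2, 0, 1], [5, 3, 4]]
```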


>>> a = np.ones((3,4,5,6))
>>> np.rollaxis(a, 3, 1).shape
(3, 6, 4, 5)
>>> np.rollaxis(a, 2).shape
(5, 3, 4, 6)
>>> np.rollaxis(a, 1, 4).shape
(3, 5, 6, 4)
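
The examples above come from the original NumPy docstring and call np.rollaxis. The sketch below runs the same calls through jax.numpy and checks the results against NumPy, assuming both numpy and jax are installed. It uses distinct values instead of np.ones, so a wrong axis order would also show up as a value mismatch and not only as a shape mismatch.

```python
import numpy as np
import jax.numpy as jnp

x = np.arange(3 * 4 * 5 * 6).reshape(3, 4, 5, 6)
a = jnp.asarray(x)

for axis, start in [(3, 1), (2, 0), (1, 4)]:
    expected = np.rollaxis(x, axis, start)
    result = jnp.rollaxis(a, axis, start)
    # Same shape and same element placement as the NumPy reference.
    assert result.shape == expected.shape
    assert np.array_equal(np.asarray(result), expected)
    print(axis, start, result.shape)
```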