jax.numpy.rollaxis(a, axis, start=0)

Roll the specified axis to a given position.

JAX implementation of numpy.rollaxis().

This function exists for compatibility with NumPy, but in most cases you should use the newer jax.numpy.moveaxis() instead, because the meaning of its arguments is more intuitive.

Parameters:

  • a (jax.typing.ArrayLike) – input array.

  • axis (int) – index of the axis to roll forward.

  • start (int) – index toward which the axis will be rolled (default = 0). After normalizing negative axes, if start <= axis, the axis is rolled to the start index; if start > axis, the axis is rolled until the position before start.

Returns:

Copy of a with rolled axis.

Return type:

Array

Notes

Unlike numpy.rollaxis(), jax.numpy.rollaxis() returns a copy rather than a view of the input array. However, under JIT the compiler optimizes away such copies when possible, so in practice this rarely has a performance impact.
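A small sketch of the copy-versus-view difference (an illustration added here, not part of the original reference): NumPy's result shares memory with its input, while the JAX result is a separate, immutable array, so a later mutation of the NumPy input is visible only through the NumPy view. The block_until_ready() call is included as a precaution because JAX dispatches work asynchronously; the .at[...].set() update at the end is the usual JAX way to produce a modified copy.

```python
import numpy as np
import jax.numpy as jnp

# A NumPy input array; float32 matches JAX's default floating dtype.
x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)

# NumPy: rollaxis returns a view that shares memory with x.
view = np.rollaxis(x, 2)
print(np.shares_memory(view, x))  # True

# JAX: rollaxis returns a new, immutable jax.Array.
result = jnp.rollaxis(x, 2)
result.block_until_ready()  # wait for the asynchronous computation to finish
print(result.shape)  # (4, 2, 3)

# Mutating the NumPy input shows up through the view, not the JAX result.
x[0, 0, 0] = 100.0
print(float(view[0, 0, 0]))    # 100.0
print(float(result[0, 0, 0]))  # 0.0

# JAX arrays are updated functionally: .at[...].set() returns a new array
# and leaves `result` unchanged.
updated = result.at[0, 0, 0].set(-1.0)
print(float(updated[0, 0, 0]), float(result[0, 0, 0]))  # -1.0 0.0
```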

See also

  • jax.numpy.moveaxis(): newer API whose arguments have clearer semantics than rollaxis().

Examples

>>> import jax.numpy as jnp
>>> a = jnp.ones((2, 3, 4, 5))

Roll axis 2 to the start of the array:

>>> jnp.rollaxis(a, 2).shape
(4, 2, 3, 5)

Roll axis 1 to the end of the array:

>>> jnp.rollaxis(a, 1, a.ndim).shape
(2, 4, 5, 3)
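To make the start rule from the parameter list concrete, the following sketch (using the same 4-D array as above) sweeps start for two different axes. For start <= axis the rolled axis lands exactly at start; for start > axis it lands one position earlier, so start = axis + 1 leaves the shape unchanged. The printed shapes are listed in the comments.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3, 4, 5))

# start <= axis: axis 2 is moved to exactly index `start`.
for start in range(3):
    print(start, jnp.rollaxis(a, 2, start).shape)
# 0 (4, 2, 3, 5)
# 1 (2, 4, 3, 5)
# 2 (2, 3, 4, 5)   no-op: axis 2 is already at index 2

# start > axis: axis 1 ends up just before index `start`.
for start in range(2, 5):
    print(start, jnp.rollaxis(a, 1, start).shape)
# 2 (2, 3, 4, 5)   no-op: start == axis + 1
# 3 (2, 4, 3, 5)
# 4 (2, 4, 5, 3)
```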

The same two rotations (axis 2 to the start, axis 1 to the end), written with moveaxis():

>>> jnp.moveaxis(a, 2, 0).shape
(4, 2, 3, 5)
>>> jnp.moveaxis(a, 1, -1).shape
(2, 4, 5, 3)
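The two moveaxis() calls above follow one general translation: normalize negative indices, then subtract one from start whenever start > axis. The sketch below spells this out; rollaxis_via_moveaxis is a hypothetical helper written for illustration, not part of jax.numpy, and it mirrors the rule described for the start parameter rather than any particular library source.

```python
import jax.numpy as jnp


def rollaxis_via_moveaxis(a, axis, start=0):
    """Hypothetical helper: rewrite rollaxis(a, axis, start) as a moveaxis call."""
    ndim = jnp.ndim(a)
    # start may also equal ndim, which means "roll the axis to the end".
    if not (-ndim <= axis < ndim and -ndim <= start <= ndim):
        raise ValueError(f"axis={axis}, start={start} out of range for ndim={ndim}")
    # Normalize negative indices to non-negative positions.
    if axis < 0:
        axis += ndim
    if start < 0:
        start += ndim
    # When start > axis, the axis ends up just before `start`.
    if start > axis:
        start -= 1
    return jnp.moveaxis(a, axis, start)


a = jnp.ones((2, 3, 4, 5))
print(rollaxis_via_moveaxis(a, 2).shape)          # (4, 2, 3, 5)
print(rollaxis_via_moveaxis(a, 1, a.ndim).shape)  # (2, 4, 5, 3)
```

In new code it is usually simpler to call jnp.moveaxis() directly with the destination index you actually want.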