jax.numpy.rollaxis

jax.numpy.rollaxis(a, axis, start=0)

Roll the specified axis to a given position.

JAX implementation of numpy.rollaxis().

This function exists for compatibility with NumPy, but in most cases you should prefer the newer jax.numpy.moveaxis(), because the meaning of its arguments is more intuitive.
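A short sketch of why moveaxis() tends to read more clearly: when start is greater than axis, rollaxis() places the axis just before start, whereas moveaxis() takes the final index of the axis directly. The array shape below is chosen only for illustration.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3, 4, 5))

# rollaxis: start=3 > axis=1, so axis 1 is rolled until just before
# position 3 and ends up at index 2.
print(jnp.rollaxis(a, 1, 3).shape)  # (2, 4, 3, 5)

# moveaxis: the destination is simply the final index of the moved axis.
print(jnp.moveaxis(a, 1, 3).shape)  # (2, 4, 5, 3)
print(jnp.moveaxis(a, 1, 2).shape)  # (2, 4, 3, 5), matches the rollaxis call
```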

Parameters:
  • a (jax.typing.ArrayLike) – input array.

  • axis (int) – index of the axis to roll forward.

  • start (int) – index toward which the axis will be rolled (default = 0). After normalizing negative axes, if start <= axis, the axis is rolled to the start index; if start > axis, the axis is rolled until the position before start.
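The normalization rule described for start can be sketched as a translation into a single moveaxis() call. Here rollaxis_via_moveaxis is a hypothetical helper written for illustration, not part of JAX, and it assumes -ndim <= axis < ndim and -ndim <= start <= ndim without checking those bounds.

```python
import jax.numpy as jnp


def rollaxis_via_moveaxis(a, axis: int, start: int = 0):
    """Hypothetical helper (not part of JAX): rollaxis expressed via moveaxis.

    Assumes -ndim <= axis < ndim and -ndim <= start <= ndim; no bounds checks.
    """
    ndim = jnp.ndim(a)
    # Normalize negative indices. Note that start may legitimately equal ndim.
    if axis < 0:
        axis += ndim
    if start < 0:
        start += ndim
    # start <= axis: the axis lands exactly at index start.
    # start > axis: the axis lands at the position just before start.
    destination = start if start <= axis else start - 1
    return jnp.moveaxis(a, axis, destination)


a = jnp.arange(120).reshape(2, 3, 4, 5)
print(rollaxis_via_moveaxis(a, 2).shape)          # (4, 2, 3, 5)
print(rollaxis_via_moveaxis(a, 1, a.ndim).shape)  # (2, 4, 5, 3)
```

Using distinct values (arange rather than ones) makes it possible to check that the data, not just the shape, matches jnp.rollaxis for every valid combination of axis and start.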

Returns:

Copy of a with rolled axis.

Return type:

Array

Notes

Unlike numpy.rollaxis(), jax.numpy.rollaxis() returns a copy rather than a view of the input array. However, under JIT the compiler optimizes away such copies when possible, so in practice this typically has no performance impact.
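A small sketch contrasting the two behaviors, assuming NumPy is installed alongside JAX. The function roll_then_sum is hypothetical and used only for illustration; inspecting the lowered program text is just one way to see that, under JIT, the roll becomes a transpose operation that the compiler can then optimize.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy: rollaxis returns a view that shares memory with its input.
x_np = np.zeros((2, 3, 4, 5))
view = np.rollaxis(x_np, 2)
print(np.shares_memory(x_np, view))  # True

# JAX: rollaxis returns a new, immutable Array with the rolled shape.
x = jnp.zeros((2, 3, 4, 5))
y = jnp.rollaxis(x, 2)
print(y.shape)  # (4, 2, 3, 5)


# Under jit, the roll is traced into the compiled program as a transpose;
# the compiler then decides whether a copy actually needs to be materialized.
@jax.jit
def roll_then_sum(x):
    return jnp.rollaxis(x, 2).sum(axis=0)


print(roll_then_sum(x).shape)  # (2, 3, 5)
print("transpose" in roll_then_sum.lower(x).as_text())
```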

See also

  • jax.numpy.moveaxis(): newer API with more intuitive argument semantics.

Examples

>>> import jax.numpy as jnp
>>> a = jnp.ones((2, 3, 4, 5))

Roll axis 2 to the start of the array:

>>> jnp.rollaxis(a, 2).shape
(4, 2, 3, 5)

Roll axis 1 to the end of the array:

>>> jnp.rollaxis(a, 1, a.ndim).shape
(2, 4, 5, 3)

The equivalent of these two operations with moveaxis():

>>> jnp.moveaxis(a, 2, 0).shape
(4, 2, 3, 5)
>>> jnp.moveaxis(a, 1, -1).shape
(2, 4, 5, 3)