jax.numpy.rollaxis

jax.numpy.rollaxis(a, axis, start=0)

Roll the specified axis backwards, until it lies in a given position.

LAX-backend implementation of rollaxis().

The JAX version of this function may in some cases return a copy rather than a view of the input.
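A minimal illustration, assuming a 4-D array of shape (3, 4, 5, 6) chosen only for demonstration. The rolled axis moves to its new position, and the remaining axes keep their relative order.

```python
import jax.numpy as jnp

# Demonstration array; only its shape matters here.
a = jnp.ones((3, 4, 5, 6))

# start <= axis: axis 3 (length 6) is rolled back to position 1.
print(jnp.rollaxis(a, 3, 1).shape)  # (3, 6, 4, 5)

# Default start=0: axis 2 (length 5) becomes the leading axis.
print(jnp.rollaxis(a, 2).shape)     # (5, 3, 4, 6)

# start > axis: axis 1 (length 4) lands just before position 4, i.e. last.
print(jnp.rollaxis(a, 1, 4).shape)  # (3, 5, 6, 4)
```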

Original docstring below.

This function continues to be supported for backward compatibility, but you should prefer moveaxis. The moveaxis function was added in NumPy 1.11.
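When migrating to moveaxis, note that the destination argument differs depending on how start compares with axis. The sketch below (array contents arbitrary, non-negative arguments assumed) checks that each rollaxis call matches a moveaxis call: the destination is start when start <= axis, and start - 1 when start > axis.

```python
import jax.numpy as jnp

# Arbitrary test data with distinct values so equality checks are meaningful.
a = jnp.arange(3 * 4 * 5 * 6).reshape(3, 4, 5, 6)

# start <= axis: moveaxis uses start itself as the destination.
r1 = jnp.rollaxis(a, 3, 1)
m1 = jnp.moveaxis(a, 3, 1)

# start > axis: rollaxis stops just *before* start, so moveaxis needs start - 1.
r2 = jnp.rollaxis(a, 1, 4)
m2 = jnp.moveaxis(a, 1, 3)

print(bool(jnp.array_equal(r1, m1)), bool(jnp.array_equal(r2, m2)))  # True True
```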

Parameters
  • a (ndarray) – Input array.

  • axis (int) – The axis to be rolled. The positions of the other axes do not change relative to one another.

  • start (int, optional) –

    When start <= axis, the axis is rolled back until it lies in this position. When start > axis, the axis is rolled until it lies before this position. The default, 0, results in a “complete” roll. The following table describes how negative values of start are interpreted:

    start            Normalized start
    -------------    ----------------
    -(arr.ndim+1)    raise AxisError
    -arr.ndim        0
    ...              ...
    -1               arr.ndim-1
    0                0
    ...              ...
    arr.ndim         arr.ndim
    arr.ndim + 1     raise AxisError
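The normalization rules for start, combined with the "rolled until it lies before this position" rule, can be expressed as a short permutation-building routine. The helper rollaxis_sketch below is hypothetical, written only to illustrate the semantics described above; it is not JAX's actual source. It raises ValueError for out-of-range arguments (NumPy's AxisError is a ValueError subclass) and applies the resulting permutation with jnp.transpose.

```python
import jax.numpy as jnp

def rollaxis_sketch(a, axis, start=0):
    """Hypothetical re-implementation of rollaxis via jnp.transpose."""
    n = a.ndim
    # Normalize axis: valid values are in [-n, n).
    if not -n <= axis < n:
        raise ValueError(f"axis {axis} is out of range for an array of ndim {n}")
    axis %= n
    # Normalize start per the table: negative values have n added to them,
    # and the normalized value must lie in [0, n].
    if start < 0:
        start += n
    if not 0 <= start <= n:
        raise ValueError(f"start is out of range for an array of ndim {n}")
    # When start > axis, the axis ends up just before `start`, so once it is
    # removed from the axis list the insertion index shifts down by one.
    if start > axis:
        start -= 1
    perm = [i for i in range(n) if i != axis]
    perm.insert(start, axis)
    return jnp.transpose(a, perm)

# start=-1 normalizes to 3; axis 0 is then placed just before position 3.
print(rollaxis_sketch(jnp.ones((3, 4, 5, 6)), 0, -1).shape)  # (4, 5, 3, 6)
```

Building an explicit permutation and calling transpose keeps the sketch close to the table: all of the start handling happens in plain integer arithmetic before any array operation runs.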

Returns

res – For NumPy >= 1.10.0, a view of a is always returned. For earlier NumPy versions, a view of a is returned only if the order of the axes is changed; otherwise the input array itself is returned.

Return type

ndarray