jax.numpy.moveaxis

jax.numpy.moveaxis(a, source, destination)

Move axes of an array to new positions.

LAX-backend implementation of numpy.moveaxis().

The JAX version of this function may in some cases return a copy rather than a view of the input.
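JAX arrays are immutable, so the view-versus-copy distinction does not affect the values a program computes. There is no in-place assignment, and updates through the `.at` property return a new array. The following is a minimal sketch, assuming `jax` is installed:

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)
y = jnp.moveaxis(x, 0, 1)  # shape (3, 2)

# JAX arrays have no in-place assignment; .at[...].set() returns a new array.
y2 = y.at[0, 0].set(100)

# Neither the input nor the first result is changed by the update.
print(int(x[0, 0]), int(y[0, 0]), int(y2[0, 0]))  # 0 0 100
```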

Original docstring below.

Other axes remain in their original order.
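This sketch, which assumes `jax` is installed, moves a single axis in each direction. The axes that are not moved keep their relative order. Each element keeps its value under the new indexing, so the result matches `jnp.transpose` with the corresponding axis order:

```python
import jax.numpy as jnp

x = jnp.arange(60).reshape(3, 4, 5)

# Move the first axis to the end; axes 1 and 2 keep their relative order.
a = jnp.moveaxis(x, 0, -1)
print(a.shape)  # (4, 5, 3)

# Move the last axis to the front; axes 0 and 1 keep their relative order.
b = jnp.moveaxis(x, -1, 0)
print(b.shape)  # (5, 3, 4)

# Each element keeps its value under the new indexing: a[j, k, i] == x[i, j, k].
print(bool(jnp.array_equal(a, jnp.transpose(x, (1, 2, 0)))))  # True
```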

Added in NumPy version 1.11.0.

Parameters:
  • a (np.ndarray) – The array whose axes should be reordered.

  • source (int or sequence of int) – Original positions of the axes to move. These must be unique.

  • destination (int or sequence of int) – Destination positions for each of the original axes. These must also be unique.

Returns:

result – Array with moved axes. In NumPy this array is a view of the input array; the JAX version may return a copy instead.

Return type:

np.ndarray
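When source and destination are sequences, they pair up element by element. Axes that are not listed fill the remaining positions in their original order. The sketch below assumes `jax` is installed. It uses `moveaxis_perm`, a hypothetical helper written only for illustration and not part of JAX's API or JAX's internal implementation, to build the equivalent axis permutation and check it against `jnp.transpose`:

```python
import jax.numpy as jnp

def moveaxis_perm(ndim, source, destination):
    """Hypothetical helper: the axis order that moveaxis produces."""
    src = [source] if isinstance(source, int) else list(source)
    dst = [destination] if isinstance(destination, int) else list(destination)
    if len(src) != len(dst):
        raise ValueError("source and destination need the same number of axes")
    for ax in src + dst:
        if not -ndim <= ax < ndim:
            raise ValueError(f"axis {ax} is out of bounds for {ndim} dimensions")
    # Normalize negative axes to their non-negative equivalents.
    src = [ax % ndim for ax in src]
    dst = [ax % ndim for ax in dst]
    if len(set(src)) != len(src) or len(set(dst)) != len(dst):
        raise ValueError("source and destination axes must be unique")
    # Axes that are not moved keep their original relative order...
    order = [ax for ax in range(ndim) if ax not in src]
    # ...and each moved axis is inserted at its destination, lowest first.
    for d, s in sorted(zip(dst, src)):
        order.insert(d, s)
    return order

x = jnp.arange(3 * 4 * 5).reshape(3, 4, 5)
perm = moveaxis_perm(x.ndim, [0, 1], [-1, -2])
print(perm)  # [2, 1, 0]

y = jnp.moveaxis(x, [0, 1], [-1, -2])
print(y.shape)  # (5, 4, 3)
print(bool(jnp.array_equal(y, jnp.transpose(x, perm))))  # True
```

Inserting the moved axes in order of increasing destination ensures that each one lands at its requested position. Any earlier insert has a smaller destination index and so cannot shift a later one.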