jax.numpy.moveaxis

jax.numpy.moveaxis(a, source, destination)

Move axes of an array to new positions.

LAX-backend implementation of numpy.moveaxis().

Original NumPy docstring below.

Other axes remain in their original order.
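A minimal sketch of moving a single axis, assuming a zero-filled 3-D array whose axis lengths differ so that each axis is easy to track by its length. Only the resulting shapes matter here.

```python
import jax.numpy as jnp

# A 3-D array with distinct axis lengths: axis 0 -> 3, axis 1 -> 4, axis 2 -> 5.
x = jnp.zeros((3, 4, 5))

# Move axis 0 to the last position; axes 1 and 2 keep their relative order.
to_end = jnp.moveaxis(x, 0, -1)
print(to_end.shape)  # (4, 5, 3)

# Move the last axis to the front; axes 0 and 1 keep their relative order.
to_front = jnp.moveaxis(x, -1, 0)
print(to_front.shape)  # (5, 3, 4)
```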

New in NumPy version 1.11.0.

Parameters
  • a (np.ndarray) – The array whose axes should be reordered.

  • source (int or sequence of int) – Original positions of the axes to move. These must be unique.

  • destination (int or sequence of int) – Destination positions for each of the original axes. These must also be unique.
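Both source and destination may also be equal-length sequences. The sketch below assumes a small integer array and introduces `moveaxis_perm`, a hypothetical helper written only for illustration (it is not part of `jax.numpy`). The helper reconstructs the axis permutation that moving the axes implies, which shows that the result matches `jnp.transpose` with that permutation.

```python
import jax.numpy as jnp

def moveaxis_perm(ndim, source, destination):
    """Hypothetical helper: the axis permutation implied by moveaxis.

    Assumes `source` and `destination` are equal-length sequences of
    unique, in-range axis indices (negative indices are allowed).
    """
    src = [s % ndim for s in source]
    dst = [d % ndim for d in destination]
    # Axes that are not moved keep their original relative order.
    perm = [i for i in range(ndim) if i not in src]
    # Insert each moved axis at its destination, lowest destination first.
    for d, s in sorted(zip(dst, src)):
        perm.insert(d, s)
    return tuple(perm)

x = jnp.arange(3 * 4 * 5).reshape(3, 4, 5)

# Move axis 0 to position -1 and axis 1 to position -2.
moved = jnp.moveaxis(x, [0, 1], [-1, -2])
print(moved.shape)  # (5, 4, 3)

# The same result expressed as an explicit transpose.
perm = moveaxis_perm(x.ndim, [0, 1], [-1, -2])
print(perm)  # (2, 1, 0)
print(bool(jnp.array_equal(moved, jnp.transpose(x, perm))))  # True
```

Inserting the moved axes in ascending destination order ensures that each insert lands at its final position, because every lower destination has already been filled.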

Returns

result – Array with moved axes. In NumPy this is a view of the input array; JAX arrays are immutable, so jax.numpy.moveaxis returns a new array instead.
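The "view" wording comes from NumPy, where writing through the result also changes the input. A minimal sketch contrasting the two libraries, assuming both `numpy` and `jax` are installed. In JAX, the closest equivalent of an in-place write is the `.at[...].set(...)` update, which returns a new array and leaves the original untouched.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: moveaxis returns a view, so writing through it changes the input.
xn = np.arange(6).reshape(2, 3)
yn = np.moveaxis(xn, 0, 1)
yn[0, 0] = 100
print(xn[0, 0])  # 100

# JAX: arrays are immutable; an "update" produces a new array instead.
xj = jnp.arange(6).reshape(2, 3)
yj = jnp.moveaxis(xj, 0, 1)
yj_updated = yj.at[0, 0].set(100)
print(int(xj[0, 0]), int(yj[0, 0]), int(yj_updated[0, 0]))  # 0 0 100
```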

Return type

np.ndarray