jax.numpy.moveaxis

jax.numpy.moveaxis(a, source, destination)

Move an array axis to a new position.

JAX implementation of numpy.moveaxis(), implemented in terms of jax.lax.transpose().
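
To illustrate how a move can be expressed as a transpose, here is a minimal sketch that builds a permutation and passes it to jax.lax.transpose(). The helper name moveaxis_sketch is hypothetical, and this is not necessarily how JAX implements the function internally; it skips the validation a real implementation would need (out-of-range or repeated axes).

```python
import jax.numpy as jnp
from jax import lax


def moveaxis_sketch(a, source, destination):
    """Hypothetical helper: move axes by building a transpose permutation."""
    a = jnp.asarray(a)
    src = (source,) if isinstance(source, int) else tuple(source)
    dst = (destination,) if isinstance(destination, int) else tuple(destination)
    if len(src) != len(dst):
        raise ValueError("source and destination must have the same length")

    # Normalize negative indices (no range checking in this sketch).
    src = tuple(s % a.ndim for s in src)
    dst = tuple(d % a.ndim for d in dst)

    # Start with the axes that stay in their relative order, then insert
    # each moved axis at its destination, lowest destination first.
    perm = [i for i in range(a.ndim) if i not in src]
    for d, s in sorted(zip(dst, src)):
        perm.insert(d, s)

    return lax.transpose(a, tuple(perm))


x = jnp.ones((2, 3, 4, 5))
moved = moveaxis_sketch(x, (0, 1), (-1, -2))
print(moved.shape)
```

For this input the permutation works out to (2, 3, 1, 0), which matches the transpose() call shown in the Examples section.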

Parameters:
  • a (jax.typing.ArrayLike) – input array

  • source (int | Sequence[int]) – index or indices of the axes to move.

  • destination (int | Sequence[int]) – index or indices of the destination positions for the moved axes. Must contain the same number of entries as source.

Returns:

Copy of a with axes moved from source to destination.

Return type:

Array

Notes

Unlike numpy.moveaxis(), jax.numpy.moveaxis() returns a copy rather than a view of the input array. However, under JIT, the compiler will optimize away such copies when possible, so this typically has no performance impact in practice.
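
The difference can be seen in a short sketch. The NumPy result shares memory with its input, while the JAX call is typically wrapped in a jitted function so the compiler can fuse the axis move with the computation that follows. The function name sum_over_moved_axis is made up for illustration, and whether the copy is actually elided depends on the compiler and backend.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy: moveaxis returns a view onto the same buffer.
a_np = np.arange(24.0).reshape(2, 3, 4)
view = np.moveaxis(a_np, 1, -1)
shares = np.shares_memory(a_np, view)


# JAX: inside jit, the axis move and the reduction are compiled together,
# giving the compiler the chance to avoid materializing the moved array.
@jax.jit
def sum_over_moved_axis(x):
    y = jnp.moveaxis(x, 1, -1)  # shape (2, 4, 3)
    return y.sum(axis=-1)       # shape (2, 4)


x = jnp.arange(24.0).reshape(2, 3, 4)
out = sum_over_moved_axis(x)
print(shares, out.shape)
```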

Examples

>>> import jax.numpy as jnp
>>> a = jnp.ones((2, 3, 4, 5))

Move axis 1 to the end of the array:

>>> jnp.moveaxis(a, 1, -1).shape
(2, 4, 5, 3)

Move the last axis to position 1:

>>> jnp.moveaxis(a, -1, 1).shape
(2, 5, 3, 4)

Move multiple axes:

>>> jnp.moveaxis(a, (0, 1), (-1, -2)).shape
(4, 5, 3, 2)

This can also be accomplished via transpose():

>>> a.transpose(2, 3, 1, 0).shape
(4, 5, 3, 2)
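
The shape checks above do not show that the values agree. As a small additional sketch, the following uses an array with distinct entries to compare the two approaches elementwise, and checks that swapping source and destination undoes the move. The variable names are illustrative only.

```python
import jax.numpy as jnp

# Distinct values, so any mismatch in axis order would be visible.
a = jnp.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

moved = jnp.moveaxis(a, (0, 1), (-1, -2))
transposed = a.transpose(2, 3, 1, 0)
same_values = bool(jnp.array_equal(moved, transposed))

# Swapping source and destination moves the axes back.
restored = jnp.moveaxis(moved, (-1, -2), (0, 1))
round_trip = bool(jnp.array_equal(restored, a))

print(same_values, round_trip)
```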