jax.numpy.swapaxes

jax.numpy.swapaxes(a, axis1, axis2)

Interchange two axes of an array.

LAX-backend implementation of numpy.swapaxes().

The JAX version of this function may in some cases return a copy rather than a view of the input.
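The copy-versus-view difference matters mostly for code that relies on writing through a NumPy view. The following is a minimal sketch of that difference, assuming both numpy and jax are installed; the variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

# NumPy: swapaxes returns a view, so a write through it reaches the original.
a = np.zeros((2, 3))
v = np.swapaxes(a, 0, 1)
v[0, 1] = 5.0
print(a[1, 0])  # 5.0

# JAX: arrays are immutable, so there is no write-through to rely on.
x = jnp.zeros((2, 3))
y = jnp.swapaxes(x, 0, 1)
y2 = y.at[0, 1].set(5.0)  # functional update; returns a new array
print(float(x[1, 0]))  # 0.0, the input is unchanged
```

Because JAX arrays cannot be modified in place, code that only reads the result should behave the same whether a copy or a view is returned.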

Original docstring below.

Parameters:
  • a (array_like) – Input array.

  • axis1 (int) – First axis.

  • axis2 (int) – Second axis.

Returns:

a_swapped – For NumPy >= 1.10.0, if a is an ndarray, then a view of a is returned; otherwise a new array is created. For earlier NumPy versions a view of a is returned only if the order of the axes is changed, otherwise the input array is returned.

Return type:

ndarray
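Examples:

A short usage sketch, not taken from the original docstring: it swaps two axes of a small 3-D array and checks where the elements end up. The array contents and variable names are chosen only for illustration.

```python
import jax.numpy as jnp

# A 3-D array with shape (2, 3, 4), filled with 0..23.
x = jnp.arange(24).reshape(2, 3, 4)

# Swap the first and last axes: shape (2, 3, 4) -> (4, 3, 2).
y = jnp.swapaxes(x, 0, 2)
print(y.shape)  # (4, 3, 2)

# Element [i, j, k] of the input ends up at [k, j, i] in the output.
assert y[3, 1, 0] == x[0, 1, 3]

# Negative axes count from the end, so -1 names the last axis here.
z = jnp.swapaxes(x, 0, -1)
assert (z == y).all()

# For a 2-D array, swapping axes 0 and 1 is a transpose.
m = jnp.array([[1, 2, 3], [4, 5, 6]])
assert (jnp.swapaxes(m, 0, 1) == m.T).all()
```

Only the two named axes trade places; every other axis keeps its position, which is why the middle axis of length 3 is unchanged above.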