jax.numpy.swapaxes
jax.numpy.swapaxes(a, axis1, axis2)
Swap two axes of an array.
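JAX implementation of numpy.swapaxes(), implemented in terms of jax.lax.transpose().

Parameters:
- a (ArrayLike): input array.
- axis1 (int): index of the first axis.
- axis2 (int): index of the second axis.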

Returns:
Copy of a with the specified axes swapped.

Return type:
Array
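Swapping two axes is a special case of permuting all of them, so the operation can be sketched directly with jax.lax.transpose(): start from the identity permutation, exchange the entries at axis1 and axis2, and transpose. This is a minimal sketch, not the jax.numpy source; the helper name swapaxes_via_transpose is hypothetical, and the negative-axis handling assumes NumPy's usual convention.

```python
import jax
import jax.numpy as jnp
from jax import lax


def swapaxes_via_transpose(a, axis1: int, axis2: int) -> jax.Array:
    """Hypothetical helper: swap two axes using lax.transpose (sketch only)."""
    a = jnp.asarray(a)
    ndim = a.ndim
    # Reject out-of-range axes, then map negative indices to non-negative ones.
    for ax in (axis1, axis2):
        if not -ndim <= ax < ndim:
            raise ValueError(f"axis {ax} is out of bounds for array of dimension {ndim}")
    axis1, axis2 = axis1 % ndim, axis2 % ndim
    # Identity permutation with the two requested entries exchanged.
    perm = list(range(ndim))
    perm[axis1], perm[axis2] = perm[axis2], perm[axis1]
    return lax.transpose(a, tuple(perm))


x = jnp.arange(24).reshape(2, 3, 4)
y = swapaxes_via_transpose(x, 0, -1)
print(y.shape)                                     # (4, 3, 2)
print(bool((y == jnp.swapaxes(x, 0, -1)).all()))   # True
```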

Notes

Unlike numpy.swapaxes(), jax.numpy.swapaxes() returns a copy rather than a view of the input array. However, under JIT the compiler optimizes away such copies when possible, so in practice this typically has no performance impact.

See also
- jax.numpy.moveaxis(): move a single axis of an array.
- jax.numpy.rollaxis(): older API for moveaxis.
- jax.lax.transpose(): more general axes permutations.
- jax.Array.swapaxes(): same functionality via an array method.
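The difference between swapaxes and the related functions listed above shows up as soon as the two axes are not adjacent. A short comparison using only public jax.numpy and jax.lax calls (the input shape is an illustrative choice):

```python
import jax.numpy as jnp
from jax import lax

a = jnp.zeros((2, 3, 4, 5))

# swapaxes exchanges axes 1 and 3; the axis between them stays in place.
print(jnp.swapaxes(a, 1, 3).shape)  # (2, 5, 4, 3)

# moveaxis removes axis 1 and reinserts it at position 3,
# shifting the intervening axes left by one.
print(jnp.moveaxis(a, 1, 3).shape)  # (2, 4, 5, 3)

# For adjacent axes the two operations agree.
print(jnp.swapaxes(a, 1, 2).shape, jnp.moveaxis(a, 1, 2).shape)  # (2, 4, 3, 5) (2, 4, 3, 5)

# lax.transpose accepts an arbitrary permutation of all axes.
print(lax.transpose(a, (3, 0, 2, 1)).shape)  # (5, 2, 4, 3)
```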
Examples
>>> import jax.numpy as jnp
>>> a = jnp.ones((2, 3, 4, 5))
>>> jnp.swapaxes(a, 1, 3).shape
(2, 5, 4, 3)

Equivalent output via the swapaxes array method:

>>> a.swapaxes(1, 3).shape
(2, 5, 4, 3)

Equivalent output via transpose():

>>> a.transpose(0, 3, 2, 1).shape
(2, 5, 4, 3)
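The copy-versus-view distinction from the Notes can be seen by comparing against NumPy. A minimal sketch; the array values and the function name column_sums are illustrative only, and the comment about JIT describes what the compiler may do, not a guarantee for every program.

```python
import jax
import jax.numpy as jnp
import numpy as np

# NumPy: swapaxes returns a view, so writing to it changes the input.
xn = np.arange(6).reshape(2, 3)
vn = np.swapaxes(xn, 0, 1)
vn[0, 1] = 100           # vn[0, 1] aliases xn[1, 0]
print(xn[1, 0])          # 100

# JAX: arrays are immutable and swapaxes returns a copy; the functional
# .at[] update builds a new array and leaves the input unchanged.
x = jnp.arange(6).reshape(2, 3)
y = jnp.swapaxes(x, 0, 1).at[0, 1].set(100)
print(x[1, 0], y[0, 1])  # 3 100

# Under jit the swap is part of the compiled computation, so the compiler
# can avoid materializing an intermediate transposed copy where possible.
@jax.jit
def column_sums(x):
    return jnp.swapaxes(x, 0, 1).sum(axis=1)

print(column_sums(x))    # [3 5 7]
```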