jax.numpy.permute_dims

jax.numpy.permute_dims(a, /, axes)

Permute the axes/dimensions of an array.

JAX implementation of array_api.permute_dims().

Parameters:
  • a (ArrayLike) – input array

  • axes (tuple[int, ...]) – tuple of integers in range [0, a.ndim) specifying the axes permutation; axis i of the result is axis axes[i] of a.
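To make the meaning of axes concrete: entry i of the tuple names the input axis that becomes axis i of the output. The sketch below, which assumes the usual `import jax.numpy as jnp` alias, checks this on a 3-D array by comparing shapes and one element.

```python
import jax.numpy as jnp

# A 3-D array with shape (2, 3, 4).
x = jnp.arange(24).reshape(2, 3, 4)

# axes[i] names the input axis that becomes output axis i:
#   output axis 0 <- input axis 2 (size 4)
#   output axis 1 <- input axis 0 (size 2)
#   output axis 2 <- input axis 1 (size 3)
y = jnp.permute_dims(x, (2, 0, 1))
print(y.shape)  # (4, 2, 3)

# Element-wise, y[k, i, j] == x[i, j, k].
print(int(y[3, 1, 2]), int(x[1, 2, 3]))  # 23 23
```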

Returns:

a copy of a with axes permuted.

Return type:

Array
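The returned array holds the same elements as a, rearranged; a itself is left unchanged because JAX arrays are immutable. The sketch below compares the result with jnp.transpose called on the same axes tuple, and shows how to undo a permutation by applying its inverse q, defined by q[p[i]] == i. Computing q with a plain Python argsort is an illustrative choice, not part of the JAX API.

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)
p = (1, 2, 0)

y = jnp.permute_dims(x, p)
# Same result as jnp.transpose with the same axes tuple.
print(bool(jnp.array_equal(y, jnp.transpose(x, p))))  # True

# The input keeps its original shape; y is a separate array.
print(x.shape, y.shape)  # (2, 3, 4) (3, 4, 2)

# Inverse permutation: sort the positions 0..n-1 by their value in p
# (an argsort of p), which yields q with q[p[i]] == i.
q = tuple(sorted(range(len(p)), key=p.__getitem__))
print(q)  # (2, 0, 1)

# Applying q after p restores the original axis order.
print(bool(jnp.array_equal(jnp.permute_dims(y, q), x)))  # True
```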

Examples

>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.permute_dims(a, (1, 0))
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
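Because axes determines the output shape, it has to consist of concrete Python integers when a function is traced. A sketch of calling permute_dims inside jax.jit with a literal axes tuple; the function name hwc_to_chw and the height-width-channel image layout are hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def hwc_to_chw(img):
    # Hypothetical helper: move the trailing channel axis to the front.
    # The axes tuple is a Python literal, so it is fixed at trace time.
    return jnp.permute_dims(img, (2, 0, 1))

img = jnp.zeros((32, 48, 3))
out = hwc_to_chw(img)
print(out.shape)  # (3, 32, 48)
```

If the permutation needs to vary between calls, one option is to pass axes as an argument listed in jax.jit's static_argnames, so that each distinct tuple is compiled separately.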