jax.numpy.permute_dims
- jax.numpy.permute_dims(a, /, axes)
Permute the axes/dimensions of an array.
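JAX implementation of array_api.permute_dims().
- Parameters:
  - a (ArrayLike) – input array.
  - axes (tuple[int, ...]) – tuple of integers in range [0, a.ndim) specifying the axes permutation.
- Returns:
  a copy of a with axes permuted.
- Return type:
  Array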
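The permutation rule is easiest to see on an array whose axes all have different lengths. The sketch below assumes a JAX release that ships jnp.permute_dims. It shows that output axis i is input axis axes[i], so the output shape is the input shape reordered by axes. It also checks that jnp.transpose with the same axes gives the same result.

```python
import jax.numpy as jnp

# A 3-D array with distinct axis lengths so the permutation is visible in the shape.
a = jnp.arange(24).reshape(2, 3, 4)

axes = (2, 0, 1)
out = jnp.permute_dims(a, axes)

# Output axis i is input axis axes[i]: shape (a.shape[2], a.shape[0], a.shape[1]).
print(out.shape)  # (4, 2, 3)

# For this permutation, out[i, j, k] corresponds to a[j, k, i].
print(bool(out[3, 1, 2] == a[1, 2, 3]))  # True

# jnp.transpose with the same axes argument yields an identical array.
print(bool(jnp.array_equal(out, jnp.transpose(a, axes))))  # True
```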
Examples
>>> import jax.numpy as jnp
>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.permute_dims(a, (1, 0))
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
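Because axes must be a permutation, applying the inverse permutation undoes a call to permute_dims. The sketch below computes that inverse with numpy.argsort. This is a choice made for this example and is not part of the permute_dims API.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.arange(24).reshape(2, 3, 4)
axes = (1, 2, 0)

# Output shape is (a.shape[1], a.shape[2], a.shape[0]).
permuted = jnp.permute_dims(a, axes)
print(permuted.shape)  # (3, 4, 2)

# argsort of a permutation is its inverse: axes[inverse[m]] == m for every m.
inverse = tuple(int(i) for i in np.argsort(axes))
print(inverse)  # (2, 0, 1)

# Permuting by the inverse restores the original axis order and values.
restored = jnp.permute_dims(permuted, inverse)
print(restored.shape)  # (2, 3, 4)
print(bool(jnp.array_equal(restored, a)))  # True
```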