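jax.numpy.permute_dims

jax.numpy.permute_dims(a, /, axes)

Returns an array with the axes of a permuted according to axes.

Parameters:
  a (jax.typing.ArrayLike) – input array.
  axes (tuple[int, ...]) – a permutation of (0, 1, ..., a.ndim - 1); axis i of the output is axis axes[i] of the input.

Return type: Array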
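A short usage sketch follows. It assumes a JAX version recent enough to provide jnp.permute_dims. It shows how the output shape is taken from the input shape in the order given by axes, which element each output position maps back to, and that the result matches jnp.transpose called with the same axes.

```python
import jax.numpy as jnp

# A 3-D array with shape (2, 3, 4)
x = jnp.arange(24).reshape(2, 3, 4)

# Output axis 0 <- input axis 2, output axis 1 <- input axis 0,
# output axis 2 <- input axis 1
y = jnp.permute_dims(x, (2, 0, 1))
print(y.shape)  # (4, 2, 3)

# Element correspondence for this permutation: y[k, i, j] == x[i, j, k]
assert y[3, 1, 2] == x[1, 2, 3]

# Same result as jnp.transpose with the same axes argument
assert (y == jnp.transpose(x, (2, 0, 1))).all()
```

The axes tuple has to list every input axis exactly once, so its length always equals a.ndim.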