jax.numpy.permute_dims

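jax.numpy.permute_dims(a, /, axes)

Permute the axes (dimensions) of an array. This is the Python Array API standard name for the operation. For the same axes tuple it returns the same result as jax.numpy.transpose(a, axes). The argument a is positional-only.
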
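Parameters:
  • a (jax.typing.ArrayLike) - input array.

  • axes (tuple[int, ...]) - the new order of the axes, containing each axis of a exactly once. Axis i of the result is axis axes[i] of a.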

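Returns:

An array with the elements of a and its axes reordered, so its shape is tuple(a.shape[i] for i in axes).

Return type:

Array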
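A short usage sketch, assuming a recent JAX release that exports jax.numpy.permute_dims (as this page documents). It shows how the axes tuple determines the output shape and element positions, checks the result against jnp.transpose, and undoes the permutation with its inverse. The variable names (x, y, axes, inv) are illustrative only.

```python
import jax.numpy as jnp

# A 3-D input array of shape (2, 3, 4).
x = jnp.arange(24).reshape(2, 3, 4)

# axes lists every axis of x once; result axis i is input axis axes[i].
axes = (2, 0, 1)
y = jnp.permute_dims(x, axes)
print(y.shape)  # (4, 2, 3)

# Elements move with their axes: y[k, i, j] == x[i, j, k].
assert y[3, 1, 2] == x[1, 2, 3]

# Same result as jnp.transpose with the same axes tuple.
assert jnp.array_equal(y, jnp.transpose(x, axes))

# The inverse permutation (indices of axes sorted by value) restores x.
inv = tuple(sorted(range(len(axes)), key=axes.__getitem__))
assert jnp.array_equal(jnp.permute_dims(y, inv), x)
```

Because the result's axis i comes from input axis axes[i], the inverse permutation is the argsort of axes, here (1, 2, 0).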