jax.numpy.transpose

jax.numpy.transpose(a, axes=None)

Permute the dimensions of an array.

LAX-backend implementation of transpose(). Original docstring below.
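To illustrate what "LAX-backend" means here, the sketch below compares jax.numpy.transpose with the lower-level jax.lax.transpose, which requires an explicit permutation (there is no "reverse by default" behaviour at the lax level). It assumes JAX is installed; the input shape and permutation are arbitrary choices for illustration.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(24).reshape(2, 3, 4)
perm = (2, 0, 1)

# jnp.transpose takes the permutation as `axes`;
# lax.transpose takes it as a required `permutation` argument.
y_jnp = jnp.transpose(x, perm)
y_lax = lax.transpose(x, perm)

print(y_jnp.shape)  # (4, 2, 3)
print(bool(jnp.array_equal(y_jnp, y_lax)))  # True
```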

Parameters
  • a (array_like) – Input array.

  • axes (tuple or list of ints, optional) – If None (the default), reverse the order of the dimensions. Otherwise, permute the axes so that axis i of the result is axis axes[i] of the input.
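The following sketch contrasts the default (reversed) order with an explicit axes tuple and checks one element to show how axes[i] maps input axes to output axes. It assumes JAX is installed; the shapes and the permutation (1, 2, 0) are illustrative only.

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)

# axes=None (the default) reverses the dimension order.
reversed_dims = jnp.transpose(x)
print(reversed_dims.shape)  # (4, 3, 2)

# With explicit axes, axis i of the result is axis axes[i] of the input.
axes = (1, 2, 0)
permuted = jnp.transpose(x, axes)
print(permuted.shape)  # (3, 4, 2)

# For this choice of axes, permuted[i, j, k] == x[k, i, j].
print(int(permuted[1, 2, 0]) == int(x[0, 1, 2]))  # True
```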

Returns

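p – a with its axes permuted. NumPy returns a view whenever possible; JAX arrays are immutable and have no view semantics, so this distinction does not apply to jax.numpy.transpose.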
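Because JAX arrays are immutable, the result of a transpose cannot be used to modify the input in place, as a NumPy view could. A minimal sketch, assuming JAX is installed, using the functional .at[...].set(...) update:

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)
y = jnp.transpose(x)  # shape (3, 2)

# An "update" on y produces a new array; x is never affected.
y2 = y.at[0, 0].set(100)

print(int(x[0, 0]))   # 0
print(int(y2[0, 0]))  # 100
```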

Return type

ndarray

See also

moveaxis(), argsort()
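As a sketch of how the related moveaxis() fits in, moving a single axis is a special case of a full permutation. The example below (assuming JAX is installed; shapes are illustrative) shows that moving axis 0 of a 3-D array to the end matches transposing with axes (1, 2, 0).

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)

moved = jnp.moveaxis(x, 0, -1)               # move axis 0 to the last position
via_transpose = jnp.transpose(x, (1, 2, 0))  # the equivalent full permutation

print(moved.shape)  # (3, 4, 2)
print(bool(jnp.array_equal(moved, via_transpose)))  # True
```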

Notes

To invert a transposition performed with the axes argument, use argsort: if b = transpose(a, axes), then transpose(b, argsort(axes)) recovers a.
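A short sketch of this round trip, assuming JAX and NumPy are installed. numpy.argsort is used on the Python tuple so the inverse axes stay static Python ints; the shape and permutation are illustrative.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.arange(24).reshape(2, 3, 4)
axes = (2, 0, 1)

b = jnp.transpose(a, axes)  # shape (4, 2, 3)

# argsort(axes) is the inverse permutation; here it is (1, 2, 0).
inverse = tuple(int(i) for i in np.argsort(axes))
restored = jnp.transpose(b, inverse)

print(inverse)         # (1, 2, 0)
print(restored.shape)  # (2, 3, 4)
print(bool(jnp.array_equal(restored, a)))  # True
```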

Transposing a 1-D array returns it unchanged (in NumPy, as a view of the original array).
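The sketch below, assuming JAX is installed, shows that a 1-D array has only one axis and therefore nothing to permute. Turning it into a column requires adding an axis (here with None indexing), not a transpose.

```python
import jax.numpy as jnp

v = jnp.array([1, 2, 3])
vt = jnp.transpose(v)

print(vt.shape)  # (3,)
print(bool(jnp.array_equal(vt, v)))  # True

# A column vector needs an extra axis, which transpose cannot add.
col = v[:, None]
print(col.shape)  # (3, 1)
```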

Examples

>>> import numpy as np
>>> x = np.arange(4).reshape((2, 2))
>>> x
array([[0, 1],
       [2, 3]])
>>> np.transpose(x)
array([[0, 2],
       [1, 3]])
>>> x = np.ones((1, 2, 3))
>>> np.transpose(x, (1, 0, 2)).shape
(2, 1, 3)
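The examples above come from the NumPy docstring. A sketch of the same calls through jax.numpy, assuming JAX is installed; .tolist() is used so the printed result does not depend on JAX's array repr.

```python
import jax.numpy as jnp

x = jnp.arange(4).reshape((2, 2))
xt = jnp.transpose(x)
print(xt.tolist())  # [[0, 2], [1, 3]]

y = jnp.ones((1, 2, 3))
print(jnp.transpose(y, (1, 0, 2)).shape)  # (2, 1, 3)
```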