jax.numpy.transpose

jax.numpy.transpose(a, axes=None)

Returns an array with axes transposed.

LAX-backend implementation of numpy.transpose().

The JAX version of this function may in some cases return a copy rather than a view of the input.

Original docstring below.

For a 1-D array, this returns an unchanged view of the original array, as a transposed vector is simply the same vector. To convert a 1-D array into a 2-D column vector, an additional dimension must be added: np.atleast_2d(a).T achieves this, as does a[:, np.newaxis]. For a 2-D array, this is the standard matrix transpose. For an n-D array, if axes are given, their order indicates how the axes are permuted (see Examples). If axes are not provided, then transpose(a).shape == a.shape[::-1].
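The 1-D, 2-D, and n-D cases described above can be checked with a short sketch. It assumes jax is installed and uses the jax.numpy counterparts of the NumPy helpers named in the docstring (jnp.atleast_2d and jnp.newaxis); the variable names are illustrative only.

```python
import jax.numpy as jnp

# 1-D input: transposing leaves the shape unchanged.
v = jnp.arange(3)                  # shape (3,)
v_t = jnp.transpose(v)             # shape (3,)

# Two ways to obtain a 2-D column vector from a 1-D array.
col_a = jnp.atleast_2d(v).T        # (1, 3) transposed to (3, 1)
col_b = v[:, jnp.newaxis]          # shape (3, 1)

# 2-D input: standard matrix transpose.
m = jnp.arange(6).reshape(2, 3)
m_t = jnp.transpose(m)             # shape (3, 2), m_t[j, i] == m[i, j]

# n-D input with no axes argument: the axis order is reversed.
t = jnp.zeros((2, 3, 4))
t_t = jnp.transpose(t)             # shape (4, 3, 2)

print(v_t.shape, col_a.shape, col_b.shape, m_t.shape, t_t.shape)
```

Only shapes are inspected here; whether the result shares memory with the input is not something this sketch tries to show.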

Parameters:
  • a (array_like) – Input array.

  • axes (tuple or list of ints, optional) – If specified, it must be a tuple or list which contains a permutation of [0,1,…,N-1] where N is the number of axes of a. The i’th axis of the returned array will correspond to the axis numbered axes[i] of the input. If not specified, defaults to range(a.ndim)[::-1], which reverses the order of the axes.
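As a rough illustration of the axes rule above, the following sketch permutes a 3-D array and compares shapes and one element. The array x and the permutation (1, 2, 0) are arbitrary choices made for this example, not values taken from the documentation.

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)
axes = (1, 2, 0)

# Axis i of the result corresponds to axis axes[i] of the input.
y = jnp.transpose(x, axes)
expected_shape = tuple(x.shape[ax] for ax in axes)

# With axes = (1, 2, 0), element y[i, j, k] equals x[k, i, j].
same_element = int(y[2, 3, 1]) == int(x[1, 2, 3])

# Omitting axes is equivalent to reversing the axis order.
default = jnp.transpose(x)
explicit_reverse = jnp.transpose(x, tuple(range(x.ndim))[::-1])

print(y.shape, expected_shape, same_element, default.shape)
```

Writing the expected shape as a comprehension over axes makes the mapping between output and input axes explicit, which can help when debugging a permutation on higher-rank arrays.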

Returns:

p – a with its axes permuted. A view is returned whenever possible.

Return type:

ndarray