jax.numpy.transpose
- jax.numpy.transpose(a, axes=None)
Returns an array with axes transposed.
LAX-backend implementation of numpy.transpose(). The JAX version of this function may in some cases return a copy rather than a view of the input.
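As a rough illustration of what "LAX-backend implementation" means in practice, the sketch below checks that jnp.transpose agrees with numpy.transpose on a small 2-D array, and that the same permutation written directly with the primitive jax.lax.transpose(operand, permutation) gives the same values. It assumes only that NumPy and JAX are installed. It does not show how jnp.transpose is implemented internally, and whether a transpose is materialized as a copy under jax.jit is left to the XLA compiler.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

# A small 2-D example input (values chosen arbitrarily).
x_np = np.arange(6).reshape(2, 3)
x = jnp.asarray(x_np)

# jnp.transpose follows the numpy.transpose() semantics.
t = jnp.transpose(x)
print(t.shape)  # (3, 2)

# The same permutation expressed directly with the LAX primitive.
t_lax = lax.transpose(x, (1, 0))

# The function can also be traced and compiled; the result is a new,
# immutable JAX array rather than a NumPy-style view.
t_jit = jax.jit(jnp.transpose)(x)
```

Because JAX arrays are immutable, the copy-versus-view distinction mentioned above cannot be observed by mutating the result, which is why JAX is free to return a copy.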
Original docstring below.
For a 1-D array, this returns an unchanged view of the original array, as a transposed vector is simply the same vector. To convert a 1-D array into a 2-D column vector, an additional dimension must be added: np.atleast_2d(a).T achieves this, as does a[:, np.newaxis]. For a 2-D array, this is the standard matrix transpose. For an n-D array, if axes are given, their order indicates how the axes are permuted (see Examples). If axes are not provided, then transpose(a).shape == a.shape[::-1].
- Parameters:
a (array_like) – Input array.
axes (tuple or list of ints, optional) – If specified, it must be a tuple or list which contains a permutation of [0, 1, …, N-1] where N is the number of axes of a. The i-th axis of the returned array will correspond to the axis numbered axes[i] of the input. If not specified, defaults to range(a.ndim)[::-1], which reverses the order of the axes.
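A minimal sketch of the axes rule, assuming a small 3-D input made up for illustration: with axes=(1, 2, 0), axis i of the result is input axis axes[i], so the result shape is (a.shape[1], a.shape[2], a.shape[0]) and the element at a[i, j, k] ends up at p[j, k, i]. Omitting axes is equivalent to passing the reversed axis order.

```python
import jax.numpy as jnp

# Example input of shape (2, 3, 4) with distinct values so elements can be tracked.
a = jnp.arange(24).reshape(2, 3, 4)

# Explicit permutation: result axis i is input axis axes[i].
axes = (1, 2, 0)
p = jnp.transpose(a, axes)
print(p.shape)  # (3, 4, 2)

# Element mapping for this permutation: p[j, k, i] == a[i, j, k].
print(bool(p[2, 3, 1] == a[1, 2, 3]))  # True

# Default: same as passing the reversed axis order (2, 1, 0).
q = jnp.transpose(a)
q_explicit = jnp.transpose(a, tuple(range(a.ndim))[::-1])
print(q.shape)  # (4, 3, 2)
```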
- Returns:
p – a with its axes permuted. A view is returned whenever possible.
- Return type:
ndarray
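The 1-D and 2-D cases described above can be sketched as follows. The vector and matrix values are arbitrary examples: transposing a 1-D array leaves its shape unchanged, while jnp.atleast_2d(v).T and v[:, jnp.newaxis] both produce a column vector, and .T on a 2-D array is the ordinary matrix transpose.

```python
import jax.numpy as jnp

v = jnp.array([1.0, 2.0, 3.0])

# Transposing a 1-D array leaves it unchanged.
v_t = jnp.transpose(v)
print(v_t.shape)  # (3,)

# Two ways to obtain a 2-D column vector of shape (3, 1).
col_a = jnp.atleast_2d(v).T   # (1, 3) row, then transposed
col_b = v[:, jnp.newaxis]     # insert a trailing axis
print(col_a.shape, col_b.shape)  # (3, 1) (3, 1)

# For a 2-D array, .T is the standard matrix transpose.
m = jnp.array([[1, 2], [3, 4]])
m_t = m.T
```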