jax.numpy.matrix_transpose

jax.numpy.matrix_transpose(x, /)

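Transposes the last two dimensions of x: an input of shape (..., M, N) yields an output of shape (..., N, M). Any leading (batch) dimensions are left unchanged.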

Parameters:

x (array_like) – Input array. Must have x.ndim >= 2.

Returns:

xT – Transposed array, with the last two axes of x swapped.

Return type:

Array
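
Example:

The following is a minimal sketch of typical usage, not taken from the original page. The input array x and its shape (2, 3, 4) are chosen purely for illustration; the comparison with jax.numpy.swapaxes is included only to show the equivalent axis swap.

```python
import jax.numpy as jnp

# Illustrative input: a batch of two 3x4 matrices, shape (2, 3, 4).
x = jnp.arange(24).reshape(2, 3, 4)

# Swap only the last two axes; the leading batch axis is untouched.
xT = jnp.matrix_transpose(x)
print(xT.shape)  # (2, 4, 3)

# The result matches an explicit swap of the last two axes.
same = jnp.array_equal(xT, jnp.swapaxes(x, -1, -2))
print(bool(same))  # True
```

For a plain 2-D input this is the ordinary matrix transpose; the function is mainly useful for stacked (batched) matrices, where a full transpose of all axes would also reverse the batch dimensions.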