jax.numpy.linalg.matrix_transpose

jax.numpy.linalg.matrix_transpose(x, /)

Transpose a matrix or stack of matrices.

JAX implementation of numpy.linalg.matrix_transpose().

Parameters:

x (jax.typing.ArrayLike) – input array of shape (..., M, N); it must have at least two dimensions.

Returns:

array of shape (..., N, M) containing the matrix transpose of x.

Return type:

Array

See also

jax.numpy.transpose(): more general transpose operation.
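To illustrate how the two functions relate, the following is a minimal sketch (the array x and the variable names are made up for illustration). It assumes a JAX version recent enough to provide jnp.linalg.matrix_transpose. Called without an axes argument, jnp.transpose reverses every axis, whereas matrix_transpose swaps only the last two:

```python
import jax.numpy as jnp

# Illustrative input: a stack of 2 matrices, each of shape (3, 4).
x = jnp.arange(24).reshape(2, 3, 4)

# matrix_transpose swaps only the last two axes; the batch axis stays first.
mt = jnp.linalg.matrix_transpose(x)
print(mt.shape)  # (2, 4, 3)

# jnp.transpose with no axes argument reverses all axes.
t = jnp.transpose(x)
print(t.shape)  # (4, 3, 2)

# Two ways to get the same result with the more general functions.
same_a = jnp.transpose(x, (0, 2, 1))
same_b = jnp.swapaxes(x, -1, -2)
print(bool(jnp.array_equal(mt, same_a)))  # True
print(bool(jnp.array_equal(mt, same_b)))  # True
```

For a two-dimensional input the two functions agree; the difference only shows up once batch dimensions are present.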

Examples

Transpose of a single matrix:

>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.linalg.matrix_transpose(x)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)

Transpose of a stack of matrices:

>>> x = jnp.array([[[1, 2],
...                 [3, 4]],
...                [[5, 6],
...                 [7, 8]]])
>>> jnp.linalg.matrix_transpose(x)
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)

For convenience, the same computation can be done via the mT property of JAX array objects:

>>> x.mT
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)
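As one possible use of the mT property, the sketch below computes a batched Gram matrix (each matrix's transpose times itself) for the same stack of matrices. The name gram is chosen for illustration only; the example relies on the @ operator broadcasting over leading batch dimensions.

```python
import jax.numpy as jnp

# Same stack of two 2x2 matrices as in the example above.
x = jnp.array([[[1, 2],
                [3, 4]],
               [[5, 6],
                [7, 8]]])

# x.mT transposes each matrix; @ then multiplies matrix by matrix
# across the leading batch dimension.
gram = x.mT @ x
print(gram)
# [[[ 10  14]
#   [ 14  20]]
#
#  [[ 74  86]
#   [ 86 100]]]

# By contrast, x.T reverses all axes, so it differs from x.mT here.
print(bool(jnp.array_equal(x.T, x.mT)))  # False
```

Each resulting matrix is symmetric, as expected of a product of the form AᵀA.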