jax.numpy.matrix_transpose


jax.numpy.matrix_transpose(x, /)

Transpose the last two dimensions of an array.

JAX implementation of numpy.matrix_transpose(), implemented in terms of jax.lax.transpose().
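
Below is a minimal sketch of how such an implementation might look. It assumes the permutation passed to jax.lax.transpose() keeps every leading (batch) axis in place and swaps only the final two. The helper name `matrix_transpose_via_lax` is hypothetical and not part of JAX. The comparison with `jnp.matrix_transpose` at the end is only a consistency check, not a claim about JAX's actual source.

```python
import jax.numpy as jnp
from jax import lax


def matrix_transpose_via_lax(x):
    """Hypothetical helper: swap the last two axes using lax.transpose."""
    x = jnp.asarray(x)
    if x.ndim < 2:
        raise ValueError(f"expected x.ndim >= 2, got {x.ndim}")
    # Keep batch axes 0 .. ndim-3 in order, then swap the final two axes.
    perm = (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2)
    return lax.transpose(x, perm)


x = jnp.arange(24).reshape(2, 3, 4)
y = matrix_transpose_via_lax(x)
print(y.shape)  # (2, 4, 3)
print(bool(jnp.array_equal(y, jnp.matrix_transpose(x))))  # True
```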

Parameters:

x (jax.typing.ArrayLike) – input array. Must have x.ndim >= 2.

Returns:

matrix-transposed copy of the array.
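
To illustrate the shape of the returned array, here is a short example; the input shapes are arbitrary choices for illustration. It compares the result with the ordinary 2-D transpose and with `jnp.swapaxes(x, -1, -2)`, which swaps the same two axes explicitly.

```python
import jax.numpy as jnp

# 2-D input: matrix_transpose agrees with the ordinary transpose.
m = jnp.arange(6).reshape(2, 3)
mt = jnp.matrix_transpose(m)
print(mt.shape)  # (3, 2)

# N-D input: leading (batch) axes are untouched; only the last two swap.
x = jnp.arange(120).reshape(5, 4, 2, 3)
xt = jnp.matrix_transpose(x)
print(xt.shape)  # (5, 4, 3, 2)

# Same result as swapping the last two axes explicitly.
print(bool(jnp.array_equal(xt, jnp.swapaxes(x, -1, -2))))  # True
```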

Return type:

Array

See also

jax.Array.mT: same operation, accessed via an Array property.

Note

Unlike numpy.matrix_transpose(), jax.numpy.matrix_transpose() returns a copy rather than a view of the input array. However, under JIT the compiler optimizes away such copies when possible, so in practice this does not affect performance.
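
Because JAX arrays are immutable, the copy-versus-view distinction is mostly invisible to user code. The sketch below shows only the observable behavior: the input is never modified, and a jit-compiled function that transposes twice returns the original values. Whether XLA actually elides or fuses the intermediate copies is a compiler decision that this example does not check; `transpose_twice` is a hypothetical helper for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(8).reshape(2, 2, 2)

# Eagerly, the result is a separate Array. "Updating" it with .at[...]
# produces yet another array and never touches x.
y = jnp.matrix_transpose(x)
y_updated = y.at[0, 0, 1].set(100)


@jax.jit
def transpose_twice(a):
    # Inside jit, XLA may simplify or fuse these transposes; that choice
    # is not observable in the returned values.
    return jnp.matrix_transpose(jnp.matrix_transpose(a))


z = transpose_twice(x)
print(int(x[0, 1, 0]), int(y_updated[0, 0, 1]))  # 2 100
print(bool(jnp.array_equal(z, x)))  # True
```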

Examples

Here is a 2x2x2 array representing a batch of two 2x2 matrices:

>>> import jax.numpy as jnp
>>> x = jnp.array([[[1, 2],
...                 [3, 4]],
...                [[5, 6],
...                 [7, 8]]])
>>> jnp.matrix_transpose(x)
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)

For convenience, you can perform the same transpose via the mT property of jax.Array:

>>> x.mT
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)
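
Because `mT` swaps only the last two axes, it composes naturally with batched matrix multiplication. The sketch below uses an arbitrary `(2, 2, 3)` input to compute one Gram matrix per batch element. It also contrasts `mT` with `jnp.transpose()` called without `axes`, which reverses all axes instead of just the last two.

```python
import jax.numpy as jnp

# A batch of two 2x3 matrices, shape (2, 2, 3).
batch = jnp.arange(12.0).reshape(2, 2, 3)

# mT swaps only the last two axes, so this is a batched (3x2) @ (2x3)
# product, giving one 3x3 Gram matrix per batch element.
gram = batch.mT @ batch
print(gram.shape)  # (2, 3, 3)

# Each slice matches the ordinary 2-D transpose-and-multiply.
print(bool(jnp.allclose(gram[1], batch[1].T @ batch[1])))  # True

# jnp.transpose without `axes` reverses *all* axes, which is different.
print(jnp.transpose(batch).shape)  # (3, 2, 2)
print(batch.mT.shape)              # (2, 3, 2)
```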