jax.numpy.linalg.matrix_transpose
jax.numpy.linalg.matrix_transpose(x, /)
Transpose a matrix or stack of matrices.
JAX implementation of numpy.linalg.matrix_transpose().

- Parameters:
  x (ArrayLike) – array of shape (..., M, N).
- Returns:
  array of shape (..., N, M) containing the matrix transpose of x.
- Return type:
  Array
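The shape contract above amounts to exchanging the last two axes while leaving any leading batch axes untouched. The following minimal sketch checks this against jax.numpy.swapaxes, which is used here only as a reference computation and not as a claim about how matrix_transpose is implemented internally; the array shapes are arbitrary illustrative choices.

```python
import jax.numpy as jnp
import numpy as np

# A stack of two 3x4 matrices: shape (..., M, N) = (2, 3, 4).
x = jnp.arange(24).reshape(2, 3, 4)

y = jnp.linalg.matrix_transpose(x)
print(y.shape)  # (2, 4, 3): batch axis kept, last two axes exchanged

# Reference computation: swap the last two axes explicitly.
y_ref = jnp.swapaxes(x, -1, -2)
print(bool(jnp.array_equal(y, y_ref)))  # True

# Entry [b, i, j] of the input becomes entry [b, j, i] of the output.
print(int(x[1, 2, 0]), int(y[1, 0, 2]))  # 20 20

# Other ArrayLike inputs work too, e.g. a NumPy array; the result is a JAX array.
z = jnp.linalg.matrix_transpose(np.ones((3, 5), dtype=np.float32))
print(z.shape)  # (5, 3)
```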
See also
jax.numpy.transpose(): more general transpose operation.

Examples
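To make the See also note concrete: jax.numpy.transpose() reverses every axis when called without an axes argument, so it agrees with matrix_transpose only for 2-D input. For batched input it needs an explicit permutation that swaps just the last two axes. The helper swap_last_two below is hypothetical, written only for this illustration.

```python
import jax.numpy as jnp

x = jnp.arange(24).reshape(2, 3, 4)  # a stack of two 3x4 matrices

# With no axes argument, jnp.transpose reverses every axis.
print(jnp.transpose(x).shape)  # (4, 3, 2)

# An explicit permutation that swaps only the last two axes
# matches matrix_transpose for this 3-D input.
a = jnp.transpose(x, (0, 2, 1))
b = jnp.linalg.matrix_transpose(x)
print(a.shape, bool(jnp.array_equal(a, b)))  # (2, 4, 3) True


def swap_last_two(arr):
    """Hypothetical helper: build the last-two-axes permutation for any ndim >= 2."""
    perm = (*range(arr.ndim - 2), arr.ndim - 1, arr.ndim - 2)
    return jnp.transpose(arr, perm)


w = jnp.ones((2, 3, 4, 5))
print(swap_last_two(w).shape)  # (2, 3, 5, 4)
```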
Transpose of a single matrix:
>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.linalg.matrix_transpose(x)
Array([[1, 4],
       [2, 5],
       [3, 6]], dtype=int32)
Transpose of a stack of matrices:
>>> x = jnp.array([[[1, 2],
...                 [3, 4]],
...                [[5, 6],
...                 [7, 8]]])
>>> jnp.linalg.matrix_transpose(x)
Array([[[1, 3],
        [2, 4]],
       [[5, 7],
        [6, 8]]], dtype=int32)
For convenience, the same computation can be done via the mT property of JAX array objects:
>>> x.mT
Array([[[1, 3],
        [2, 4]],
       [[5, 7],
        [6, 8]]], dtype=int32)
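Because mT swaps only the last two axes, it composes naturally with batched matrix multiplication. The sketch below (the array A and its shape are illustrative choices, not taken from the API) forms a batch of Gram matrices A[i]^T A[i] and contrasts mT with the T property, which reverses every axis, like jax.numpy.transpose() with no axes argument.

```python
import jax.numpy as jnp

# A stack of four 3x2 matrices: shape (4, 3, 2).
A = jnp.arange(24.0).reshape(4, 3, 2)

# A.mT has shape (4, 2, 3); batched matmul with A gives four 2x2 Gram matrices.
gram = A.mT @ A
print(gram.shape)  # (4, 2, 2)

# Each slice equals the ordinary 2-D product A[i].T @ A[i].
print(bool(jnp.allclose(gram[1], A[1].T @ A[1])))  # True

# By contrast, A.T reverses all axes, which mixes the batch axis into the matrix axes.
print(A.mT.shape, A.T.shape)  # (4, 2, 3) (2, 3, 4)
```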