jax.numpy.matrix_transpose
- jax.numpy.matrix_transpose(x, /)
Transpose the last two dimensions of an array.
JAX implementation of numpy.matrix_transpose(), implemented in terms of jax.lax.transpose().
- Parameters:
x (ArrayLike) – input array. Must have x.ndim >= 2.
- Returns:
matrix-transposed copy of the array.
- Return type:
Array
See also
jax.Array.mT: same operation accessed via an Array property.
jax.numpy.transpose(): general multi-axis transpose.
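As a rough sketch of how these operations relate, the snippet below checks, on an arbitrary 3-D example array, that jnp.matrix_transpose agrees with jnp.transpose given an explicit axes permutation, with jnp.swapaxes(x, -1, -2), and with jax.lax.transpose given the same permutation. The helper `last_two_axes_swapped` is hypothetical, introduced only for this illustration; it is not part of the JAX API.

```python
import jax
import jax.numpy as jnp
import numpy as np

def last_two_axes_swapped(ndim):
    # Hypothetical helper: the identity permutation with the last two axes
    # exchanged, e.g. ndim=3 -> (0, 2, 1).
    perm = list(range(ndim))
    perm[-2], perm[-1] = perm[-1], perm[-2]
    return tuple(perm)

x = jnp.arange(24).reshape(2, 3, 4)  # a batch of two 3x4 matrices
perm = last_two_axes_swapped(x.ndim)

a = jnp.matrix_transpose(x)        # swaps only the last two axes
b = jnp.transpose(x, axes=perm)    # general transpose with an explicit permutation
c = jnp.swapaxes(x, -1, -2)        # swap the two trailing axes directly
d = jax.lax.transpose(x, perm)     # low-level lax primitive with the same permutation

print(a.shape)  # (2, 4, 3)
for other in (b, c, d):
    # All four spellings should produce identical arrays.
    np.testing.assert_array_equal(a, other)
```

Whichever spelling is used, the leading (batch) axes keep their positions; only the two trailing axes are exchanged.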
Note
Unlike numpy.matrix_transpose(), jax.numpy.matrix_transpose() returns a copy rather than a view of the input array. However, under JIT the compiler optimizes away such copies when possible, so in practice this usually has no performance impact.
Examples
Here is a 2x2x2 array representing a batch of two 2x2 matrices:
>>> import jax.numpy as jnp
>>> x = jnp.array([[[1, 2],
...                 [3, 4]],
...                [[5, 6],
...                 [7, 8]]])
>>> jnp.matrix_transpose(x)
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)
For convenience, you can perform the same transpose via the mT property of jax.Array:
>>> x.mT
Array([[[1, 3],
        [2, 4]],

       [[5, 7],
        [6, 8]]], dtype=int32)
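Because the operation permutes only the last two axes, it composes naturally with batched matrix products. The sketch below assumes a hypothetical helper `batched_gram` (not a JAX function) that computes A @ Aᵀ for every matrix A in a batch. Wrapping it in jax.jit gives the compiler the chance to avoid materializing the transposed copy, which is the situation described in the note on copies.

```python
import jax
import jax.numpy as jnp

@jax.jit
def batched_gram(x):
    # Hypothetical helper: for each matrix A in the batch, compute A @ A^T.
    # matrix_transpose swaps only the last two axes, so the batch axis is untouched,
    # and `@` broadcasts over that leading batch axis.
    return x @ jnp.matrix_transpose(x)

x = jnp.array([[[1., 2.],
                [3., 4.]],
               [[5., 6.],
                [7., 8.]]])
g = batched_gram(x)
# g[0] is [[5, 11], [11, 25]] and g[1] is [[61, 83], [83, 113]]
print(g)
```

Each result is symmetric, so g and g.mT hold the same values.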