jax.numpy.linalg.matmul

jax.numpy.linalg.matmul(x1, x2, /, *, precision=None, preferred_element_type=None)
Perform a matrix multiplication.
JAX implementation of numpy.linalg.matmul().

Parameters:
- x1 (ArrayLike) – first input array, of shape (..., N).
- x2 (ArrayLike) – second input array. Must have shape (N,) or (..., N, M). In the multi-dimensional case, leading dimensions must be broadcast-compatible with the leading dimensions of x1 (see the batched example under Examples below).
- precision (PrecisionLike | None) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST), or a tuple of two such values indicating the precision of x1 and x2.
- preferred_element_type (DTypeLike | None) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype. Both keyword arguments are illustrated in the sketch below.
Returns:
array containing the matrix product of the inputs. Shape is x1.shape[:-1] if x2.ndim == 1; otherwise the shape is (..., M) (illustrated in the final example below).

Return type:
Array
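As a minimal sketch of the two keyword arguments (the dtypes, values, and the choice of Precision.HIGHEST here are illustrative assumptions, not taken from this page):

>>> import jax
>>> import jax.numpy as jnp
>>> x1 = jnp.array([1, 2, 3], dtype=jnp.int8)
>>> x2 = jnp.array([4, 5, 6], dtype=jnp.int8)
>>> jnp.linalg.matmul(x1, x2, preferred_element_type=jnp.int32)  # accumulate int8 inputs in int32
Array(32, dtype=int32)
>>> a = jnp.array([[1., 2.], [3., 4.]])
>>> b = jnp.array([[5., 6.], [7., 8.]])
>>> jnp.linalg.matmul(a, b, precision=jax.lax.Precision.HIGHEST)  # request the highest-precision matmul
Array([[19., 22.],
       [43., 50.]], dtype=float32)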
See also

- jax.numpy.matmul(): NumPy API for this function.
- jax.numpy.linalg.vecdot(): batched vector product.
- jax.numpy.linalg.tensordot(): batched tensor product.

Examples
Vector dot products:
>>> x1 = jnp.array([1, 2, 3])
>>> x2 = jnp.array([4, 5, 6])
>>> jnp.linalg.matmul(x1, x2)
Array(32, dtype=int32)
Matrix dot product:
>>> x1 = jnp.array([[1, 2, 3],
...                 [4, 5, 6]])
>>> x2 = jnp.array([[1, 2],
...                 [3, 4],
...                 [5, 6]])
>>> jnp.linalg.matmul(x1, x2)
Array([[22, 28],
       [49, 64]], dtype=int32)
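The leading-dimension broadcasting described in the x2 parameter can be sketched as follows; the shapes here are arbitrary illustrations, not taken from this page:

>>> a = jnp.arange(24.).reshape(2, 3, 4)   # batch of two 3x4 matrices
>>> b = jnp.ones((4, 5))                   # a single 4x5 matrix, broadcast over the batch
>>> jnp.linalg.matmul(a, b).shape
(2, 3, 5)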
For convenience, in all cases you can do the same computation using the @ operator:

>>> x1 @ x2
Array([[22, 28],
       [49, 64]], dtype=int32)
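Finally, a sketch of the shape rule from the Returns section (shapes chosen only for illustration): a 1-D x2 is contracted away, leaving x1.shape[:-1]:

>>> a = jnp.ones((2, 3, 4))   # batch of 3x4 matrices
>>> v = jnp.ones(4)           # 1-D vector on the right
>>> jnp.linalg.matmul(a, v).shape   # equals a.shape[:-1]
(2, 3)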