jax.numpy.matmul
jax.numpy.matmul(a, b, *, precision=None)
Matrix product of two arrays.
LAX-backend implementation of matmul(). In addition to the original NumPy arguments listed below, also supports precision for extra control over matrix-multiplication precision on supported devices. precision may be set to None, which means default precision for the backend, a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two lax.Precision enums indicating separate precision for each argument.
Original docstring below.
matmul(x1, x2, /, out=None, *, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])
- Parameters
out (ndarray, optional) – A location into which the result is stored. If provided, it must have a shape that matches the signature (n,k),(k,m)->(n,m). If not provided or None, a freshly-allocated array is returned.
- Returns
y – The matrix product of the inputs. This is a scalar only when both inputs are 1-D vectors.
- Return type
- Raises
ValueError – If the last dimension of a is not the same size as the second-to-last dimension of b, or if a scalar value is passed in.
See also
vdot()
Complex-conjugating dot product.
tensordot()
Sum products over arbitrary axes.
einsum()
Einstein summation convention.
dot()
Alternative matrix product with different broadcasting rules.
Notes
The behavior depends on the arguments in the following way.
If both arguments are 2-D they are multiplied like conventional matrices.
If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.
If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. After matrix multiplication the prepended 1 is removed.
If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. After matrix multiplication the appended 1 is removed.
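The four cases above can be checked by looking at result shapes. The following is a minimal sketch using jax.numpy; the array names and sizes are hypothetical and chosen only for illustration.

```python
import jax.numpy as jnp

# Hypothetical shapes, chosen only to illustrate the rules listed above.
mat = jnp.ones((3, 4))        # 2-D matrix
vec_k = jnp.ones((4,))        # 1-D, length matches the last axis of mat
vec_n = jnp.ones((3,))        # 1-D, length matches the first axis of mat
stack = jnp.ones((5, 4, 2))   # stack of five 4x2 matrices

# 2-D with 2-D: conventional matrix product.
shape_2d_2d = jnp.matmul(mat, jnp.ones((4, 2))).shape
# 2-D with N-D: the matrix is broadcast against the stack.
shape_nd = jnp.matmul(mat, stack).shape
# 1-D first argument: (3,) is treated as (1, 3), then the 1 is dropped.
shape_1d_first = jnp.matmul(vec_n, mat).shape
# 1-D second argument: (4,) is treated as (4, 1), then the 1 is dropped.
shape_1d_second = jnp.matmul(mat, vec_k).shape

print(shape_2d_2d, shape_nd, shape_1d_first, shape_1d_second)
```

Only the last two axes take part in the product; any leading axes follow ordinary broadcasting.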
matmul differs from dot in two important ways:
Multiplication by scalars is not allowed; use * instead.
Stacks of matrices are broadcast together as if the matrices were elements, respecting the signature (n,k),(k,m)->(n,m):
>>> a = np.ones([9, 5, 7, 4])
>>> c = np.ones([9, 5, 4, 3])
>>> np.dot(a, c).shape
(9, 5, 7, 9, 5, 3)
>>> np.matmul(a, c).shape
(9, 5, 7, 3)
>>> # n is 7, k is 4, m is 3
The matmul function implements the semantics of the @ operator introduced in Python 3.5 following PEP 465.
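As a small illustration of the @ equivalence, the sketch below compares the function call with the operator on JAX arrays. The inputs are arbitrary and the variable names are hypothetical.

```python
import jax.numpy as jnp

# Arbitrary example inputs with compatible inner dimensions (3).
a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(3, 4)

via_function = jnp.matmul(a, b)   # explicit function call
via_operator = a @ b              # operator form on JAX arrays

# Both spellings are expected to give the same values.
same = bool(jnp.allclose(via_function, via_operator))
print(via_operator.shape, same)
```

The function form is needed when passing the precision keyword, which the operator cannot accept.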
Examples
For 2-D arrays it is the matrix product:
>>> a = np.array([[1, 0],
...               [0, 1]])
>>> b = np.array([[4, 1],
...               [2, 2]])
>>> np.matmul(a, b)
array([[4, 1],
       [2, 2]])
For 2-D mixed with 1-D, the result is the usual matrix-vector product:
>>> a = np.array([[1, 0],
...               [0, 1]])
>>> b = np.array([1, 2])
>>> np.matmul(a, b)
array([1, 2])
>>> np.matmul(b, a)
array([1, 2])
Broadcasting is conventional for stacks of arrays:
>>> a = np.arange(2 * 2 * 4).reshape((2, 2, 4))
>>> b = np.arange(2 * 2 * 4).reshape((2, 4, 2))
>>> np.matmul(a, b).shape
(2, 2, 2)
>>> np.matmul(a, b)[0, 1, 1]
98
>>> sum(a[0, 1, :] * b[0, :, 1])
98
Vector, vector returns the scalar inner product, but neither argument is complex-conjugated:
>>> np.matmul([2j, 3j], [2j, 3j])
(-13+0j)
Scalar multiplication raises an error.
>>> np.matmul([1,2], 3)
Traceback (most recent call last):
...
ValueError: matmul: Input operand 1 does not have enough dimensions ...
New in version 1.10.0.
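The precision keyword described at the top of this page is specific to the JAX version of the function. The sketch below shows the three accepted forms: None, a single lax.Precision value, and a tuple of two values. The inputs are hypothetical, and on a CPU backend I would not expect the setting to change the result; it is aimed at accelerators that can trade precision for speed.

```python
import jax
import jax.numpy as jnp

# Hypothetical inputs: with all-ones matrices every output entry is exactly 16.
a = jnp.ones((8, 16), dtype=jnp.float32)
b = jnp.ones((16, 4), dtype=jnp.float32)

# precision=None (the default) uses the backend's default precision.
default = jnp.matmul(a, b)

# A single enum value applies to both arguments.
highest = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)

# A tuple of two enum values sets the precision for each argument separately.
mixed = jnp.matmul(
    a, b, precision=(jax.lax.Precision.HIGH, jax.lax.Precision.DEFAULT)
)

print(default.shape, highest.shape, mixed.shape)
```

Whether a higher setting is worth its cost depends on the device and the workload, so it is best measured rather than assumed.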