jax.numpy.matmul
- jax.numpy.matmul(a, b, *, precision=None, preferred_element_type=None)
Matrix product of two arrays.
LAX-backend implementation of numpy.matmul(). In addition to the original NumPy arguments listed below, it also supports precision for extra control over matrix-multiplication precision on supported devices. precision may be set to None (the backend's default precision), a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two Precision enums giving a separate precision for each argument.

Original docstring below.
- Parameters:
a (ArrayLike) – First input array.
b (ArrayLike) – Second input array.
precision (PrecisionLike) – Matrix-multiplication precision; see the description of precision above.
preferred_element_type (dtype, optional) – If specified, accumulate results and return a result of the given data type. If not specified, the accumulation dtype is determined from the type promotion rules of the input array dtypes.
out (ndarray, optional) – A location into which the result is stored. If provided, it must have a shape that matches the signature (n,k),(k,m)->(n,m). If not provided or None, a freshly-allocated array is returned. (NumPy only: out is not part of the jax.numpy.matmul signature above.)
**kwargs – For other keyword-only arguments, see the ufunc docs. (NumPy only: jax.numpy.matmul accepts no additional keyword arguments.)
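The sketch below illustrates preferred_element_type. The int8 dtype and the value 100 are illustrative choices, not taken from the docstring: each exact entry of the product is 3 * 100 * 100 = 30000, which int8 cannot represent. Without preferred_element_type the result dtype follows type promotion (int8 here), while requesting int32 accumulates and returns int32 values.

```python
import jax.numpy as jnp

# int8 inputs whose exact matrix product does not fit in int8 (max 127).
a = jnp.full((2, 3), 100, dtype=jnp.int8)
b = jnp.full((3, 2), 100, dtype=jnp.int8)

# No preferred_element_type: the result dtype is the promoted input dtype (int8),
# so the entries cannot equal the true sums of 30000.
c_int8 = jnp.matmul(a, b)

# Accumulate and return in int32, which holds the exact result.
c_int32 = jnp.matmul(a, b, preferred_element_type=jnp.int32)

print(c_int8.dtype, c_int32.dtype)  # int8 int32
print(c_int32)  # every entry is 30000
```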
- Returns:
y – The matrix product of the inputs. This is a scalar only when both a and b are 1-D vectors.
- Return type:
ndarray
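A short sketch of the result shapes, following numpy.matmul's rules. In JAX the "scalar" case is returned as a zero-dimensional array (shape ()) rather than a Python float. The array values below are arbitrary illustrations.

```python
import jax.numpy as jnp

v = jnp.array([1.0, 2.0, 3.0])
w = jnp.array([4.0, 5.0, 6.0])

# 1-D @ 1-D: inner product, returned as a 0-d array (1*4 + 2*5 + 3*6 = 32.0).
s = jnp.matmul(v, w)
print(s.shape)  # ()

# (n, k) @ (k, m) -> (n, m)
m = jnp.matmul(jnp.ones((2, 3)), jnp.ones((3, 4)))
print(m.shape)  # (2, 4)

# Leading (batch) dimensions broadcast: (5, 1) with (4,) -> (5, 4),
# and each (2, 3) @ (3, 7) matrix product gives (2, 7).
batched = jnp.matmul(jnp.ones((5, 1, 2, 3)), jnp.ones((4, 3, 7)))
print(batched.shape)  # (5, 4, 2, 7)

# A 1-D first operand is treated as a row vector and the added axis is removed:
# (3,) @ (3, 4) -> (4,)
row = jnp.matmul(v, jnp.ones((3, 4)))
print(row.shape)  # (4,)
```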