jax.numpy.matmul

jax.numpy.matmul(a, b, *, precision=None)

Matrix product of two arrays.

LAX-backend implementation of numpy.matmul().
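A minimal sketch of basic use. On 2-D inputs, jnp.matmul computes the ordinary matrix product and agrees with NumPy's numpy.matmul. Because it is built on LAX primitives, it can also be traced and compiled with jax.jit. The input values are arbitrary illustration data.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Arbitrary illustrative inputs: a (2, 3) matrix and a (3, 2) matrix.
a = jnp.arange(6.0).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]
b = jnp.arange(6.0).reshape(3, 2)   # [[0, 1], [2, 3], [4, 5]]

c = jnp.matmul(a, b)                # shape (2, 2): [[10, 13], [28, 40]]
c_op = a @ b                        # the @ operator computes the same product
c_jit = jax.jit(jnp.matmul)(a, b)   # same computation, traced and compiled

# Reference result computed by NumPy on host copies of the inputs.
c_ref = np.matmul(np.asarray(a), np.asarray(b))
```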

In addition to the original NumPy arguments listed below, this function accepts a keyword-only precision argument for finer control over matrix-multiplication precision on supported devices. precision may be one of:

  • None, the default, which uses the backend's default precision;

  • a single lax.Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST);

  • a tuple of two lax.Precision enum values, giving a separate precision for each argument.
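The following sketch shows the three accepted forms of precision. The inputs are arbitrary float32 data. On CPU all three results typically agree to float32 accuracy. Differences show up mainly on accelerators such as TPUs, where the default may use lower-precision passes. The tolerances below are therefore deliberately loose for the non-HIGHEST cases.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

# Arbitrary float32 inputs of shape (3, 4) and (4, 2).
x = jnp.linspace(0.0, 1.0, 12, dtype=jnp.float32).reshape(3, 4)
y = jnp.linspace(1.0, 2.0, 8, dtype=jnp.float32).reshape(4, 2)

# precision=None (the default): use the backend's default precision.
p_default = jnp.matmul(x, y)

# A single enum value applied to both operands.
p_highest = jnp.matmul(x, y, precision=lax.Precision.HIGHEST)

# A pair of enum values: the first applies to x, the second to y.
p_mixed = jnp.matmul(
    x, y, precision=(lax.Precision.HIGHEST, lax.Precision.DEFAULT)
)

# Float64 reference computed with NumPy on the host.
ref = np.asarray(x, dtype=np.float64) @ np.asarray(y, dtype=np.float64)
```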

The original NumPy docstring follows.

Parameters
  • x1 (array_like) – First input array; scalars are not allowed. Passed as a in the JAX signature above.

  • x2 (array_like) – Second input array; scalars are not allowed. Passed as b in the JAX signature above.
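A small sketch of the "scalars not allowed" rule. A 0-d operand makes jnp.matmul raise an exception. The sketch catches both ValueError and TypeError rather than relying on one specific exception type. For scaling by a scalar, elementwise multiplication is the appropriate operation.

```python
import jax.numpy as jnp

v = jnp.array([1.0, 2.0, 3.0])

# A Python scalar (a 0-d operand) is rejected by matmul.
try:
    jnp.matmul(2.0, v)
    scalar_rejected = False
except (TypeError, ValueError):
    scalar_rejected = True

# Scaling by a scalar is an elementwise product instead.
scaled = 2.0 * v   # [2.0, 4.0, 6.0]
```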

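  • out (ndarray, optional) – A location into which the result is stored. If provided, it must have a shape that matches the signature (n,k),(k,m)->(n,m). If not provided or None, a freshly allocated array is returned. This NumPy argument is not part of the JAX signature above; because JAX arrays are immutable, jax.numpy.matmul always returns a new array.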
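Since there is no in-place destination in JAX, a common functional alternative to out is shown below. It computes the product and then writes it into a slice of a larger array with the .at[...].set() update syntax, which returns a new array and leaves the original untouched. The shapes follow the (n,k),(k,m)->(n,m) rule with n=2, k=3, m=4. The buffer and its shape are illustrative choices.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3))      # (n, k) with n=2, k=3
b = jnp.ones((3, 4))      # (k, m) with m=4
prod = jnp.matmul(a, b)   # (n, m) == (2, 4); every entry equals k == 3

buf = jnp.zeros((4, 4))
# Functional update: returns a new array; buf itself is unchanged.
updated = buf.at[:2, :].set(prod)
```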

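  • **kwargs –

    For other keyword-only arguments, see the NumPy ufunc docs.

    New in NumPy version 1.16: now handles ufunc kwargs. These keyword arguments are likewise not part of the JAX signature above, whose only keyword-only argument is precision.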

Returns

y – The matrix product of the inputs. This is a scalar only when both x1 and x2 are 1-d vectors; in JAX, the result in that case is a 0-d array rather than a Python or NumPy scalar.
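The following sketch illustrates the result-shape rules that NumPy defines for matmul and jnp.matmul follows. A 1-d operand is promoted to a matrix for the product, and the added dimension is removed afterwards. Leading (batch) dimensions of N-d operands broadcast. The product of two 1-d inputs comes back as a 0-d array. The operand values are arbitrary.

```python
import jax.numpy as jnp

u = jnp.array([1.0, 2.0, 3.0])        # 1-d, length 3
m = jnp.arange(6.0).reshape(2, 3)     # 2-d, shape (2, 3)
batch = jnp.ones((5, 3, 4))           # a stack of five (3, 4) matrices
w = jnp.ones((4, 2))

inner = jnp.matmul(u, u)        # (3,) @ (3,) -> shape (), value 1 + 4 + 9 = 14
mat_vec = jnp.matmul(m, u)      # (2, 3) @ (3,) -> (2,), values [8, 26]
vec_mat = jnp.matmul(u, batch)  # (3,) @ (5, 3, 4) -> (5, 4), each entry sum(u) = 6
batched = jnp.matmul(batch, w)  # (5, 3, 4) @ (4, 2) -> (5, 3, 2); w broadcasts over the batch
```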

Return type

ndarray