jax.numpy.inner

jax.numpy.inner(a, b, *, precision=None)

Inner product of two arrays.

LAX-backend implementation of numpy.inner().

In addition to the original NumPy arguments listed below, this function supports a precision argument, which gives extra control over matrix-multiplication precision on supported devices. precision may be None (the backend's default precision), a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST), or a tuple of two lax.Precision enums giving a separate precision for each argument.
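A minimal sketch of the three accepted forms of precision. The array names and shapes are illustrative assumptions, and the inputs are all ones, so every variant should give the same values; on real data and on accelerators the precision setting may change rounding behavior.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative inputs (assumed shapes); the last dimensions match (16).
x = jnp.ones((8, 16), dtype=jnp.float32)
y = jnp.ones((4, 16), dtype=jnp.float32)

# precision=None: use the backend's default precision.
out_default = jnp.inner(x, y)

# A single lax.Precision value applies to both arguments.
out_highest = jnp.inner(x, y, precision=lax.Precision.HIGHEST)

# A tuple of two values sets the precision of each argument separately.
out_mixed = jnp.inner(
    x, y, precision=(lax.Precision.HIGH, lax.Precision.HIGHEST)
)

print(out_default.shape, out_highest.shape, out_mixed.shape)
```

On CPU the setting typically makes no visible difference; it matters mainly on hardware that can trade matrix-multiplication accuracy for speed.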

Original docstring below.

Ordinary inner product of vectors for 1-D arrays (without complex conjugation); in higher dimensions, a sum product over the last axes.
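The 1-D case can be sketched as follows. The input values are made up for illustration, and jnp.vdot is used only as a contrast, since it conjugates its first argument while inner does not.

```python
import jax.numpy as jnp

# Real 1-D inputs: the ordinary dot product 1*4 + 2*5 + 3*6.
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])
real_inner = jnp.inner(a, b)

# Complex 1-D input: inner() does NOT conjugate either argument,
# so inner(z, z) is sum(z * z), not sum(|z|^2).
z = jnp.array([1 + 2j, 3 - 1j])
complex_inner = jnp.inner(z, z)

# vdot() conjugates its first argument, giving sum(conj(z) * z).
complex_vdot = jnp.vdot(z, z)

print(real_inner, complex_inner, complex_vdot)
```

For a conjugating inner product of complex vectors, jnp.vdot (or an explicit jnp.conj on one argument) is the usual choice.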

Parameters
  • a (array_like) – If a and b are nonscalar, their last dimensions must match.

  • b (array_like) – If a and b are nonscalar, their last dimensions must match.

Returns

out – out.shape = a.shape[:-1] + b.shape[:-1]

Return type

ndarray
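The shape rule above can be checked with a short sketch. The shapes below are arbitrary assumptions chosen so that the last dimensions match, and the comparisons with matrix multiplication and einsum are only cross-checks of the sum product over the last axes.

```python
import jax.numpy as jnp

# 2-D inputs whose last dimensions match (3).
A = jnp.arange(6.0).reshape(2, 3)
B = jnp.arange(12.0).reshape(4, 3)
M = jnp.inner(A, B)   # shape = A.shape[:-1] + B.shape[:-1]
M_ref = A @ B.T       # same sum product, written as a matmul

# Higher-rank input: only the last axes are contracted.
C = jnp.ones((2, 3, 5))
D = jnp.ones((4, 5))
T = jnp.inner(C, D)
T_ref = jnp.einsum("ijk,lk->ijl", C, D)

# If either argument is a scalar, the result is an elementwise product.
S = jnp.inner(A, 2.0)

print(M.shape, T.shape, S.shape)
```

Here M has shape (2, 4), T has shape (2, 3, 4), and S keeps the shape of A.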