jax.numpy.inner


jax.numpy.inner(a, b, *, precision=None, preferred_element_type=None)

Inner product of two arrays.

LAX-backend implementation of numpy.inner().
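Because jax.numpy.inner mirrors numpy.inner, the two should agree on the same inputs up to floating-point tolerance. The minimal sketch below assumes both NumPy and JAX are installed and that JAX's 64-bit mode is disabled (the default), so JAX computes in float32 while NumPy uses float64. The input shapes are illustrative.

```python
import numpy as np
import jax.numpy as jnp

# Small 2-D inputs whose last dimensions (3) match.
a = np.arange(6.0).reshape(2, 3)
b = np.arange(12.0).reshape(4, 3)

expected = np.inner(a, b)  # NumPy reference result (float64)
result = jnp.inner(a, b)   # JAX result (float32 unless x64 mode is enabled)

# Shapes match and values agree within tolerance.
same_shape = result.shape == expected.shape
values_agree = np.allclose(np.asarray(result), expected)
```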

In addition to the original NumPy arguments listed below, this function also supports precision for extra control over matrix-multiplication precision on supported devices. precision may be set to None, meaning the backend's default precision; a Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST); or a tuple of two Precision enums giving a separate precision for each argument.
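The three accepted forms of precision can be sketched as follows. This assumes the Precision enum mentioned above is jax.lax.Precision; the random inputs and their shapes are purely illustrative. On CPU the results typically agree to float32 rounding, while backends that lower precision by default (for example, TPU) may show small differences between the default and HIGHEST results.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Random float32 inputs whose last dimensions (16) match.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (8, 16), dtype=jnp.float32)
b = jax.random.normal(key_b, (4, 16), dtype=jnp.float32)

# precision=None: use the backend's default precision.
default = jnp.inner(a, b)

# A single Precision enum value applies to both operands.
highest = jnp.inner(a, b, precision=lax.Precision.HIGHEST)

# A tuple of two enums sets the precision for a and for b separately.
mixed = jnp.inner(a, b, precision=(lax.Precision.HIGHEST, lax.Precision.DEFAULT))

# Largest elementwise gap between default and highest-precision results.
max_diff = jnp.max(jnp.abs(default - highest))
```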

Original docstring below.

For 1-D arrays, this is the ordinary inner product of vectors (without complex conjugation); in higher dimensions, it is a sum product over the last axes.
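Both behaviors described above can be checked directly. In this sketch, the sum product over last axes is compared against jnp.einsum (2-D index form) and jnp.tensordot contracting axis -1 of each input. The lack of conjugation is contrasted with jnp.vdot, which conjugates its first argument. The specific arrays are illustrative.

```python
import jax.numpy as jnp

# 1-D inputs: ordinary inner product, the sum over i of u[i] * v[i].
u = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([4.0, 5.0, 6.0])
dot_1d = jnp.inner(u, v)  # 1*4 + 2*5 + 3*6 = 32

# N-D inputs: sum product over the last axis of each argument.
a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(4, 3)
via_inner = jnp.inner(a, b)
via_einsum = jnp.einsum("ik,jk->ij", a, b)           # explicit index form for 2-D inputs
via_tensordot = jnp.tensordot(a, b, axes=(-1, -1))   # contract last axis of a with last axis of b

# Complex inputs: inner does not conjugate, whereas vdot conjugates its first argument.
z = jnp.array([1 + 1j, 2 - 1j])
no_conj = jnp.inner(z, z)   # (1+1j)**2 + (2-1j)**2 = 3 - 2j
with_conj = jnp.vdot(z, z)  # |1+1j|**2 + |2-1j|**2 = 7
```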

Parameters:
  • a (array_like) – If a and b are nonscalar, their last dimensions must match.

  • b (array_like) – If a and b are nonscalar, their last dimensions must match.

  • preferred_element_type (dtype, optional) – If specified, accumulate results and return a result of the given data type. If not specified, the accumulation dtype is determined from the type promotion rules of the input array dtypes.

  • precision (str | Precision | tuple[str, str] | tuple[Precision, Precision] | None) – Optional control over matrix-multiplication precision on supported devices, as described above.
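A common use of preferred_element_type is accumulating low-precision integer inputs in a wider type. The sketch below uses int8 inputs whose exact inner product (4 × 100 × 100 = 40000) does not fit in int8. It assumes only the documented behavior: with preferred_element_type set, the result has that dtype; without it, the dtype follows type promotion of the inputs (int8 here).

```python
import jax.numpy as jnp

# int8 inputs; each element is 100, which fits in int8, but the sum of products does not.
a = jnp.full((4,), 100, dtype=jnp.int8)
b = jnp.full((4,), 100, dtype=jnp.int8)

# Accumulate in int32 and return an int32 result.
wide = jnp.inner(a, b, preferred_element_type=jnp.int32)

# Without preferred_element_type, the result dtype comes from type promotion (int8),
# which cannot represent 40000.
narrow = jnp.inner(a, b)
```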

Returns:

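out – If a and b are both scalars or both 1-D arrays, a scalar is returned (in JAX, a 0-d array); otherwise an array is returned. When both a and b are nonscalar, out.shape = (*a.shape[:-1], *b.shape[:-1]); if one of them is a scalar, the result is the elementwise product and has the shape of the other argument.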
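The shape rules for the result can be checked with a short sketch. The shapes below are arbitrary illustrative choices; the scalar-argument case relies on the NumPy behavior that a scalar operand yields an elementwise product.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3, 4))
b = jnp.ones((5, 4))

# Both nonscalar: out.shape == (*a.shape[:-1], *b.shape[:-1]) == (2, 3, 5).
out = jnp.inner(a, b)
expected_shape = (*a.shape[:-1], *b.shape[:-1])

# Both 1-D: the result is 0-d (scalar-shaped).
s = jnp.inner(jnp.ones(4), jnp.ones(4))

# One scalar argument: elementwise product, shaped like the other input.
scaled = jnp.inner(a, 2.0)
```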

Return type:

ndarray