jax.numpy.vdot

jax.numpy.vdot(a, b, *, precision=None, preferred_element_type=None)

Return the dot product of two vectors.

LAX-backend implementation of numpy.vdot().

In addition to the original NumPy arguments listed below, this function also supports precision, which gives extra control over matrix-multiplication precision on supported devices. precision may be None (the default precision for the backend), a Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST), or a tuple of two Precision enums giving a separate precision for each argument.

Original docstring below.

The vdot(a, b) function handles complex numbers differently than dot(a, b): if the first argument is complex, its complex conjugate is used in the calculation of the dot product.
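The conjugation described above can be seen in a short sketch. The input values are illustrative only, and the variable names are not part of the API.

```python
import jax.numpy as jnp

# Illustrative complex inputs (arbitrary values chosen for this sketch).
a = jnp.array([1 + 2j, 3 + 4j])
b = jnp.array([5 + 6j, 7 + 8j])

# vdot conjugates its first argument: sum(conj(a) * b).
vdot_ab = jnp.vdot(a, b)

# dot does not conjugate: sum(a * b).
dot_ab = jnp.dot(a, b)

# Spelling out the conjugation by hand gives the same value as vdot.
manual = jnp.sum(jnp.conj(a) * b)

print(vdot_ab)  # (70-8j)
print(dot_ab)   # (-18+68j)
```

Because only the first argument is conjugated, swapping the arguments yields the complex conjugate of the result: jnp.vdot(b, a) is (70+8j) here.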

Note that vdot handles multidimensional arrays differently than dot: it does not perform a matrix product, but flattens input arguments to 1-D vectors first. Consequently, it should only be used for vectors.
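As a minimal sketch of the flattening behavior, the following compares vdot and dot on two small 2-D arrays. The values are illustrative only.

```python
import jax.numpy as jnp

# Two 2x2 integer arrays (arbitrary values chosen for this sketch).
a = jnp.array([[1, 4], [5, 6]])
b = jnp.array([[4, 1], [2, 2]])

# vdot flattens both inputs first: 1*4 + 4*1 + 5*2 + 6*2 = 30.
flat_result = jnp.vdot(a, b)

# Flattening explicitly gives the same scalar.
explicit = jnp.vdot(a.ravel(), b.ravel())

# dot, by contrast, performs a 2x2 matrix product.
matrix_result = jnp.dot(a, b)

print(flat_result)          # 30
print(matrix_result.shape)  # (2, 2)
```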

Parameters:
  • a (array_like) – First argument to the dot product. If a is complex, its complex conjugate is taken before the dot product is calculated.

  • b (array_like) – Second argument to the dot product.

  • preferred_element_type (dtype, optional) – If specified, accumulate results and return a result of the given data type. If not specified, the accumulation dtype is determined from the type promotion rules of the input array dtypes.

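  • precision (PrecisionLike) – Either None (the default precision for the backend), a Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST), or a tuple of two Precision enums giving a separate precision for each argument.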

Returns:

output – Dot product of a and b. Can be an int, float, or complex depending on the types of a and b.

Return type:

ndarray
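The precision and preferred_element_type arguments might be used as in the sketch below. It assumes the Precision enum is accessed as jax.lax.Precision, and the inputs are illustrative only. Whether a higher precision setting changes the numerical result depends on the backend, so no difference is claimed here.

```python
import jax
import jax.numpy as jnp

a = jnp.arange(4, dtype=jnp.float32)  # [0., 1., 2., 3.]
b = jnp.ones(4, dtype=jnp.float32)

# Default precision for the backend (precision=None).
r_default = jnp.vdot(a, b)

# A single Precision enum value applied to both arguments.
r_highest = jnp.vdot(a, b, precision=jax.lax.Precision.HIGHEST)

# A tuple of two Precision enums, one per argument.
r_pair = jnp.vdot(
    a, b, precision=(jax.lax.Precision.HIGH, jax.lax.Precision.HIGHEST)
)

# preferred_element_type requests the accumulation and result dtype:
# int8 inputs with an int32 result.
x = jnp.array([1, 2, 3], dtype=jnp.int8)
r_int32 = jnp.vdot(x, x, preferred_element_type=jnp.int32)

print(r_default, r_highest, r_pair)
print(r_int32.dtype)  # int32
```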