jax.numpy.vdot

jax.numpy.vdot(a, b, *, precision=None)

Return the dot product of two vectors.

LAX-backend implementation of vdot(). In addition to the original NumPy arguments listed below, it also accepts precision, which gives extra control over matrix-multiplication precision on supported devices. precision may be None, meaning the backend's default precision, or any jax.lax.Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST).
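A minimal sketch of passing precision, assuming a standard JAX install. The inputs are small integers stored as float32, so every precision setting should give the same answer here. Any visible difference would typically only appear with general float32 data on backends that reduce matmul precision by default (TPUs are the usual example).

```python
import jax.numpy as jnp
from jax import lax

# Two small float32 vectors; their products are exactly representable,
# so the precision setting does not change the result in this example.
a = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
b = jnp.array([4.0, 5.0, 6.0], dtype=jnp.float32)

default = jnp.vdot(a, b)                                   # precision=None: backend default
highest = jnp.vdot(a, b, precision=lax.Precision.HIGHEST)  # request the most accurate mode

print(float(default), float(highest))  # 1*4 + 2*5 + 3*6 = 32.0 for both
```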

Original docstring below.

vdot(a, b)

The vdot(a, b) function handles complex numbers differently than dot(a, b). If the first argument is complex, its complex conjugate is used in computing the dot product.
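The conjugation rule can be checked directly in JAX: vdot(a, b) should match sum(conj(a) * b), while dot(a, b) multiplies without conjugating. This is a small sketch using the same vectors as the NumPy example further down.

```python
import jax.numpy as jnp

a = jnp.array([1 + 2j, 3 + 4j])
b = jnp.array([5 + 6j, 7 + 8j])

v = jnp.vdot(a, b)                  # conjugates a before multiplying
manual = jnp.sum(jnp.conj(a) * b)   # the same computation spelled out
d = jnp.dot(a, b)                   # no conjugation

print(v.item())              # (70-8j)
print(d.item())              # (-18+68j)
print(jnp.vdot(b, a).item()) # (70+8j): swapping arguments conjugates the result
```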

Note that vdot handles multidimensional arrays differently than dot: it does not perform a matrix product, but flattens input arguments to 1-D vectors first. Consequently, it should only be used for vectors.
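To contrast the two behaviours, the sketch below applies both functions to the same 2-D inputs (arbitrary example values). vdot gives one scalar, equal to dot on the flattened arrays, whereas dot computes a matrix product.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([[5, 6], [7, 8]])

# vdot flattens both inputs to length-4 vectors: 1*5 + 2*6 + 3*7 + 4*8 = 70.
flat = jnp.vdot(a, b)
# Same result as an ordinary dot product of the raveled arrays.
raveled = jnp.dot(a.ravel(), b.ravel())
# dot on 2-D inputs is a matrix product: [[19, 22], [43, 50]], shape (2, 2).
matrix = jnp.dot(a, b)

print(int(flat), matrix.shape)
```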

Returns
output : ndarray

Dot product of a and b. Can be an int, float, or complex depending on the types of a and b.
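In JAX the result comes back as a zero-dimensional array rather than a Python scalar, with a dtype that follows JAX's type-promotion rules. The exact widths (for example int32 versus int64) depend on whether 64-bit mode is enabled, so this sketch only inspects the dtype kind.

```python
import jax.numpy as jnp

ints = jnp.vdot(jnp.array([1, 2]), jnp.array([3, 4]))        # integer inputs
floats = jnp.vdot(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))  # float inputs
mixed = jnp.vdot(jnp.array([1, 2]), jnp.array([1j, 2j]))     # int and complex

for r in (ints, floats, mixed):
    # Each result is a 0-d array; report its shape and dtype.
    print(r.shape, r.dtype)
```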

See also

dot : Return the dot product without using the complex conjugate of the first argument.

Examples

>>> import numpy as np
>>> a = np.array([1+2j, 3+4j])
>>> b = np.array([5+6j, 7+8j])
>>> np.vdot(a, b)
(70-8j)
>>> np.vdot(b, a)
(70+8j)

Note that higher-dimensional arrays are flattened!

>>> a = np.array([[1, 4], [5, 6]])
>>> b = np.array([[4, 1], [2, 2]])
>>> np.vdot(a, b)
30
>>> np.vdot(b, a)
30
>>> 1*4 + 4*1 + 5*2 + 6*2
30