jax.numpy.vdot

jax.numpy.vdot(a, b, *, precision=None, preferred_element_type=None)

Perform a conjugate multiplication of two 1D vectors: the complex conjugate of a is multiplied elementwise by b, and the products are summed.

JAX implementation of numpy.vdot().
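As a rough mental model of the description above, vdot behaves like "flatten both inputs, conjugate the first, multiply elementwise, then sum". The sketch below compares jnp.vdot against that manual formulation; manual_vdot is a hypothetical helper written only for illustration and is not part of the JAX API.

```python
import jax.numpy as jnp

def manual_vdot(a, b):
    # Hypothetical helper: flatten both inputs, conjugate the first,
    # then reduce the elementwise product to a single scalar.
    a = jnp.ravel(jnp.asarray(a))
    b = jnp.ravel(jnp.asarray(b))
    return jnp.sum(jnp.conj(a) * b)

a = jnp.array([1 + 2j, 3 - 1j])
b = jnp.array([2 - 1j, 1j])

expected = manual_vdot(a, b)
result = jnp.vdot(a, b)

# conj(a) * b = [-5j, -1 + 3j], so both values should equal -1 - 2j,
# and the result is a 0-d (scalar) array.
print(result, result.shape)
```

For real-valued inputs the conjugation is a no-op, so the same helper reduces to an ordinary flattened dot product.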

Parameters:
  • a (ArrayLike) – first input array; if it is not 1D, it is flattened.

  • b (ArrayLike) – second input array; if it is not 1D, it is flattened. Must satisfy a.size == b.size.

  • precision (PrecisionLike) – either None (default), meaning the backend's default precision; a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST); or a tuple of two such values giving the precision for a and b respectively.

  • preferred_element_type (DTypeLike | None) – either None (default), meaning the default accumulation type for the input types, or a dtype in which to accumulate the products and in which to return the result.
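The parameters above can be exercised together in a short sketch. It assumes only the documented behaviour: inputs with equal size are flattened (so their shapes may differ), precision accepts a jax.lax.Precision value, and preferred_element_type widens the accumulation dtype. The int8 to int32 case is chosen here purely as an illustration of avoiding a narrow accumulator.

```python
import jax
import jax.numpy as jnp

# A 2x2 array and a length-4 vector have the same size, so both are
# flattened to length-4 vectors: 1*5 + 2*6 + 3*7 + 4*8 = 70.
a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
b = jnp.array([5.0, 6.0, 7.0, 8.0])
flat = jnp.vdot(a, b)

# Ask the backend for its highest-precision dot product.
precise = jnp.vdot(a, b, precision=jax.lax.Precision.HIGHEST)

# Accumulate int8 inputs in int32, so the sum 100*2 + 100*2 = 400 is
# not constrained to the int8 range [-128, 127].
x8 = jnp.array([100, 100], dtype=jnp.int8)
y8 = jnp.array([2, 2], dtype=jnp.int8)
wide = jnp.vdot(x8, y8, preferred_element_type=jnp.int32)

print(flat, precise, wide, wide.dtype)
```

Without preferred_element_type, the int8 case would accumulate and return in the input dtype, which cannot represent 400.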

Returns:

Scalar array (shape ()) containing the conjugate vector product of the inputs.

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([1j, 2j, 3j])
>>> y = jnp.array([1., 2., 3.])
>>> jnp.vdot(x, y)
Array(0.-14.j, dtype=complex64)

Note the difference from dot(), which does not conjugate the first input when it is complex:

>>> jnp.dot(x, y)
Array(0.+14.j, dtype=complex64)
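The contrast with dot() can be made explicit with a small sketch reusing the x and y from the examples above. It checks three consequences of conjugating the first argument: vdot matches dot with an explicit jnp.conj, swapping the arguments conjugates the result, and vdot of a vector with itself gives its squared Euclidean norm.

```python
import jax.numpy as jnp

x = jnp.array([1j, 2j, 3j])
y = jnp.array([1.0, 2.0, 3.0])

# vdot conjugates its first argument, so it matches dot with an explicit conj.
v = jnp.vdot(x, y)
d = jnp.dot(jnp.conj(x), y)

# Argument order matters: vdot(y, x) conjugates y (a no-op for real y)
# instead of x, giving the complex conjugate of vdot(x, y).
swapped = jnp.vdot(y, x)

# vdot(x, x) = sum(|x_i|**2) = 1 + 4 + 9 = 14: real-valued, stored as complex.
sq_norm = jnp.vdot(x, x)

print(v, d, swapped, sq_norm)
```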