jax.numpy.vecdot

jax.numpy.vecdot(x1, x2, /, *, axis=-1, precision=None, preferred_element_type=None)

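Compute the conjugate dot product of two batched vectors: the first argument is conjugated, multiplied elementwise with the second, and the products are summed along axis.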

JAX implementation of numpy.vecdot().

Parameters:
  • x1 (ArrayLike) – left-hand side array.

  • x2 (ArrayLike) – right-hand side array. The size of x2 along axis must match the size of x1 along axis, and the remaining dimensions must be broadcast-compatible.

  • axis (int) – axis along which to compute the dot product (default: -1).

  • precision (PrecisionLike) – either None (default), which means the default precision for the backend; a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST); or a tuple of two such values indicating the precision of x1 and x2.

  • preferred_element_type (DTypeLike | None) – either None (default), which means the default accumulation type for the input types, or a datatype in which to accumulate and return the result.
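
As a minimal sketch of how axis and preferred_element_type behave (the array values and variable names are illustrative, not taken from the JAX documentation):

```python
import jax.numpy as jnp

x1 = jnp.array([[1, 2, 3],
                [4, 5, 6]])   # shape (2, 3)
x2 = jnp.array([[1, 0, 1],
                [0, 1, 0]])   # shape (2, 3)

# Default axis=-1: contract along the last axis, one value per row.
per_row = jnp.vecdot(x1, x2)

# axis=0: contract along the first axis, one value per column.
per_col = jnp.vecdot(x1, x2, axis=0)

# int8 inputs whose products do not fit in int8; request int32
# accumulation and an int32 result via preferred_element_type.
small = jnp.array([100, 100], dtype=jnp.int8)
wide = jnp.vecdot(small, small, preferred_element_type=jnp.int32)
```

Here per_row has shape (2,) and per_col has shape (3,), since the contracted axis is removed from the result. The last call is the kind of case preferred_element_type is meant for: narrow inputs whose dot product needs a wider result type.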

Returns:

array containing the conjugate dot product of x1 and x2 along axis. The non-contracted dimensions are broadcast together.
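
One way to read this return value is as a conjugate, multiply, and sum along axis. The sketch below checks that reading against jnp.vecdot on small illustrative inputs; the name reference is a hypothetical helper variable, and the explicit formula is an equivalent computation for comparison, not a claim about how JAX implements the function:

```python
import jax.numpy as jnp

x1 = jnp.array([[1 + 2j, 3 - 1j],
                [0 + 1j, 2 + 0j]])   # shape (2, 2)
x2 = jnp.array([[1 - 1j, 2 + 2j]])   # shape (1, 2), broadcast against both rows

result = jnp.vecdot(x1, x2)

# Reference computation: conjugate x1, multiply elementwise with x2,
# then sum over the contracted (last) axis.
reference = jnp.sum(jnp.conj(x1) * x2, axis=-1)

matches = bool(jnp.allclose(result, reference))
```

The leading dimensions (2,) and (1,) broadcast to (2,), so result holds one value per row of x1.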

Return type:

Array

Examples

Vector conjugate-dot product of two 1D arrays:

>>> import jax.numpy as jnp
>>> a = jnp.array([1j, 2j, 3j])
>>> b = jnp.array([4., 5., 6.])
>>> jnp.vecdot(a, b)
Array(0.-32.j, dtype=complex64)

Batched vector dot product of two 2D arrays:

>>> a = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> b = jnp.array([[2, 3, 4]])
>>> jnp.vecdot(a, b, axis=-1)
Array([20, 47], dtype=int32)
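
Because only the first argument is conjugated, argument order matters for complex inputs. A short sketch reusing the 1D arrays from the first example (the names ab and ba are illustrative):

```python
import jax.numpy as jnp

a = jnp.array([1j, 2j, 3j])
b = jnp.array([4., 5., 6.])

# conj(a) . b: the imaginary entries of a flip sign before the sum.
ab = jnp.vecdot(a, b)

# conj(b) . a: b is real, so conjugation leaves it unchanged.
ba = jnp.vecdot(b, a)

# Swapping the arguments yields the complex conjugate of the result.
swapped_is_conjugate = bool(jnp.allclose(ab, jnp.conj(ba)))
```

For real-valued inputs the two orders agree, since conjugation has no effect.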