jax.numpy.dot

jax.numpy.dot(a, b, *, precision=None, preferred_element_type=None)

Compute the dot product of two arrays.

JAX implementation of numpy.dot().

This differs from jax.numpy.matmul() in two respects:

  • if either a or b is a scalar, the result of dot is equivalent to jax.numpy.multiply(), while the result of matmul is an error.

  • if a and b have more than 2 dimensions, the batch indices are stacked rather than broadcast.
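Both differences can be checked directly. The sketch below uses only the public `jnp.dot`, `jnp.multiply`, and `jnp.matmul` functions; the shapes are arbitrary choices, with batch sizes 2 and 5 picked so that they are not broadcast-compatible.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# Difference 1: dot accepts a scalar operand and multiplies element-wise...
scalar_dot = jnp.dot(x, 2)
same_as_multiply = bool(jnp.array_equal(scalar_dot, jnp.multiply(x, 2)))

# ...while matmul rejects a 0-d operand.
try:
    jnp.matmul(x, 2)
    matmul_scalar_ok = True
except (TypeError, ValueError):
    matmul_scalar_ok = False

# Difference 2: dot stacks batch dimensions, matmul broadcasts them.
# Batch sizes 2 and 5 cannot be broadcast together, so only dot accepts them.
a = jnp.ones((2, 3, 4))
b = jnp.ones((5, 4, 6))
stacked_shape = jnp.dot(a, b).shape  # a's leading dims, b's batch dims, b's last dim
try:
    jnp.matmul(a, b)
    matmul_batch_ok = True
except (TypeError, ValueError):
    matmul_batch_ok = False

print(same_as_multiply, matmul_scalar_ok, stacked_shape, matmul_batch_ok)
# True False (2, 3, 5, 6) False
```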

Parameters:
  • a (ArrayLike) – first input array, of shape (..., N).

  • b (ArrayLike) – second input array. Must have shape (N,) or (..., N, M). In the multi-dimensional case, the leading dimensions of b are stacked after those of a in the output rather than broadcast, so they need not be compatible with the leading dimensions of a.

  • precision (PrecisionLike) – either None (default), meaning the backend's default precision; a Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST); or a tuple of two such values giving the precision of a and b respectively.

  • preferred_element_type (DTypeLike | None) – either None (default), meaning the default accumulation type for the input types, or a dtype, in which case results are accumulated in that dtype and returned with it.
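A minimal sketch of the two optional keyword arguments, assuming the Precision enum is imported from `jax.lax`. The int8 inputs and the specific values are illustrative only. On CPU the precision setting usually has no visible effect; it mainly matters on accelerators, where lower-precision passes may be used by default.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1, 2, 3], dtype=jnp.int8)
y = jnp.array([4, 5, 6], dtype=jnp.int8)

# preferred_element_type=None: the result keeps the promoted input dtype (int8).
default = jnp.dot(x, y)
# preferred_element_type=jnp.int32: accumulate in and return int32, leaving
# headroom for sums that would not fit in int8.
widened = jnp.dot(x, y, preferred_element_type=jnp.int32)

m = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
v = jnp.array([1.0, 0.5, 0.25], dtype=jnp.float32)
# A single Precision value applies to both operands...
highest = jnp.dot(m, v, precision=lax.Precision.HIGHEST)
# ...while a tuple sets the precision of a and b separately.
mixed = jnp.dot(m, v, precision=(lax.Precision.HIGHEST, lax.Precision.DEFAULT))

print(default.dtype, widened.dtype, int(widened))  # int8 int32 32
print(highest, mixed)
```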

Returns:

array containing the dot product of the inputs, with batch dimensions of a and b stacked rather than broadcast.

Return type:

Array
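The output shape follows the NumPy dot rule. The hypothetical helper `dot_output_shape` below is not part of JAX; it is a sketch of that rule, which can be handy for checking shapes before calling jnp.dot.

```python
import jax.numpy as jnp


def dot_output_shape(a_shape, b_shape):
    """Hypothetical helper: predict jnp.dot(a, b).shape from the input shapes."""
    a_shape, b_shape = tuple(a_shape), tuple(b_shape)
    # Scalar operand: one shape is empty, so the result has the other's shape.
    if len(a_shape) == 0 or len(b_shape) == 0:
        return a_shape + b_shape
    # The last axis of a is contracted with b's only axis (1-D b)
    # or with b's second-to-last axis (N-D b).
    k = b_shape[0] if len(b_shape) == 1 else b_shape[-2]
    if a_shape[-1] != k:
        raise ValueError(f"contraction size mismatch: {a_shape} vs {b_shape}")
    if len(b_shape) == 1:
        return a_shape[:-1]
    # Remaining dims of a, then b's leading (batch) dims, then b's last dim:
    # the batch dims are stacked, never broadcast against each other.
    return a_shape[:-1] + b_shape[:-2] + b_shape[-1:]


print(dot_output_shape((5, 2, 4), (3, 4, 1)))  # (5, 2, 3, 1)
print(jnp.dot(jnp.zeros((5, 2, 4)), jnp.zeros((3, 4, 1))).shape)  # (5, 2, 3, 1)
```

Note that the leading sizes 5 and 3 in the example could not be broadcast together, yet dot accepts them because it stacks them.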

See also

Examples

When one of the inputs is a scalar, dot computes the element-wise product:

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3])
>>> jnp.dot(x, 2)
Array([2, 4, 6], dtype=int32)

For vector or matrix inputs, dot computes the vector or matrix product:

>>> M = jnp.array([[2, 3, 4],
...                [5, 6, 7],
...                [8, 9, 0]])
>>> jnp.dot(M, x)
Array([20, 38, 26], dtype=int32)
>>> jnp.dot(M, M)
Array([[ 51,  60,  29],
       [ 96, 114,  62],
       [ 61,  78,  95]], dtype=int32)
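Both products are ordinary sum products over the shared axis. As a cross-check using only `jnp.einsum`, the same results can be written in explicit index notation:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
M = jnp.array([[2, 3, 4],
               [5, 6, 7],
               [8, 9, 0]])

# Matrix-vector product: mv[i] = sum_j M[i, j] * x[j]
mv = jnp.einsum("ij,j->i", M, x)
# Matrix-matrix product: mm[i, k] = sum_j M[i, j] * M[j, k]
mm = jnp.einsum("ij,jk->ik", M, M)

print(bool(jnp.array_equal(mv, jnp.dot(M, x))), bool(jnp.array_equal(mm, jnp.dot(M, M))))
# True True
```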

For higher-dimensional matrix products, batch dimensions are stacked, whereas in matmul() they are broadcast. For example:

>>> a = jnp.zeros((3, 2, 4))
>>> b = jnp.zeros((3, 4, 1))
>>> jnp.dot(a, b).shape
(3, 2, 3, 1)
>>> jnp.matmul(a, b).shape
(3, 2, 1)
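The two results are related: for each batch index i, matmul pairs a[i] only with b[i], which corresponds to the diagonal of the two stacked batch axes produced by dot. A sketch with small integer-valued float32 inputs:

```python
import jax.numpy as jnp

a = jnp.arange(24, dtype=jnp.float32).reshape(3, 2, 4)
b = jnp.arange(12, dtype=jnp.float32).reshape(3, 4, 1)

stacked = jnp.dot(a, b)        # (3, 2, 3, 1): every batch of a with every batch of b
broadcast = jnp.matmul(a, b)   # (3, 2, 1): batch i of a with batch i of b only

# matmul's result is the diagonal over dot's two batch axes: stacked[i, :, i, :]
diag = jnp.stack([stacked[i, :, i, :] for i in range(a.shape[0])])
print(bool(jnp.allclose(diag, broadcast)))  # True
```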