jax.numpy.dot(a, b, *, precision=None)

Dot product of two arrays.

LAX-backend implementation of dot(). In addition to the original NumPy arguments listed below, it also supports a precision argument for extra control over matrix-multiplication precision on supported devices. precision may be None, which selects the default precision for the backend, or any jax.lax.Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST).
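As a minimal sketch of the precision argument described above, the following calls jnp.dot once with the default setting and once with Precision.HIGHEST. The array shapes and variable names are illustrative assumptions; whether the two settings produce different numerics depends on the device.

```python
import jax
import jax.numpy as jnp

# Illustrative inputs (shapes chosen arbitrarily for this sketch).
a = jnp.ones((4, 3), dtype=jnp.float32)
b = jnp.ones((3, 2), dtype=jnp.float32)

# precision=None (the default) leaves the choice to the backend.
default_result = jnp.dot(a, b)

# Explicitly request the highest precision via the jax.lax.Precision enum.
highest_result = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)

print(default_result.shape)
print(highest_result.shape)
```

On devices without reduced-precision matrix units, the two results would typically be identical.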

Original docstring below.

dot(a, b, out=None)

  • If both a and b are 1-D arrays, it is the inner product of vectors (without complex conjugation).

  • If both a and b are 2-D arrays, it is matrix multiplication, but using matmul() or a @ b is preferred.

  • If either a or b is 0-D (scalar), it is equivalent to multiply(), and using numpy.multiply(a, b) or a * b is preferred.

  • If a is an N-D array and b is a 1-D array, it is a sum product over the last axis of a and b.

  • If a is an N-D array and b is an M-D array (where M>=2), it is a sum product over the last axis of a and the second-to-last axis of b:

    dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])
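The cases listed above can be sketched with small arrays as follows. The array names and shapes are assumptions made for illustration; the final lines check one element of the N-D by M-D result against the sum-product formula.

```python
import jax.numpy as jnp

v = jnp.arange(3.0)                  # 1-D, shape (3,): [0., 1., 2.]
w = jnp.array([1.0, 2.0, 3.0])       # 1-D, shape (3,)
m = jnp.arange(6.0).reshape(2, 3)    # 2-D, shape (2, 3)
n = jnp.arange(12.0).reshape(3, 4)   # 2-D, shape (3, 4)

inner = jnp.dot(v, w)        # 1-D with 1-D: inner product, a scalar
matrix = jnp.dot(m, n)       # 2-D with 2-D: matrix product, shape (2, 4)
scaled = jnp.dot(2.0, m)     # scalar operand: elementwise multiplication
last_axis = jnp.dot(m, v)    # N-D with 1-D: sum over last axis of m, shape (2,)

# N-D with M-D: contract the last axis of a with the second-to-last axis of b.
a = jnp.arange(24.0).reshape(2, 3, 4)
b = jnp.arange(40.0).reshape(2, 4, 5)
nd = jnp.dot(a, b)                         # shape (2, 3, 2, 5)
elem = jnp.sum(a[1, 2, :] * b[0, :, 3])    # should match nd[1, 2, 0, 3]

print(inner, matrix.shape, last_axis.shape, nd.shape)
```

Note that the result shape in the last case is the shape of a without its last axis followed by the shape of b without its second-to-last axis, which is where dot() differs from matmul() for inputs with more than two dimensions.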

Returns the dot product of a and b. If a and b are both scalars or both 1-D arrays then a scalar is returned; otherwise an array is returned. If out is given, then it is returned.


Raises ValueError if the last dimension of a is not the same size as the second-to-last dimension of b.
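A short sketch of the shape-mismatch case follows. The original NumPy docstring documents a ValueError here; the exception type raised by the JAX implementation may differ, so this example assumes either TypeError or ValueError and catches both.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3))   # last dimension has size 3
b = jnp.ones((2, 3))   # second-to-last dimension has size 2, so shapes do not align

try:
    jnp.dot(a, b)
    raised = False
except (TypeError, ValueError) as err:
    # Assumption: the exact exception class is not specified by the text above.
    raised = True
    print(type(err).__name__)
```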

See also:

  • vdot: Complex-conjugating dot product.

  • tensordot: Sum products over arbitrary axes.

  • einsum: Einstein summation convention.

  • matmul: '@' operator as method with out parameter.

>>> np.dot(3, 4)
12

Neither argument is complex-conjugated:

>>> np.dot([2j, 3j], [2j, 3j])
(-13+0j)

For 2-D arrays it is the matrix product:

>>> a = [[1, 0], [0, 1]]
>>> b = [[4, 1], [2, 2]]
>>> np.dot(a, b)
array([[4, 1],
       [2, 2]])
>>> a = np.arange(3*4*5*6).reshape((3,4,5,6))
>>> b = np.arange(3*4*5*6)[::-1].reshape((5,4,6,3))
>>> np.dot(a, b)[2,3,2,1,2,2]
499128
>>> sum(a[2,3,2,:] * b[1,2,:,2])
499128