- jax.numpy.dot(a, b, *, precision=None)
Dot product of two arrays.
LAX-backend implementation of numpy.dot().
In addition to the original NumPy arguments listed below, also supports precision for extra control over matrix-multiplication precision on supported devices.
precision may be set to None, which means default precision for the backend, a jax.lax.Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST), or a tuple of two Precision enums indicating separate precision for each argument.
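A minimal sketch of the three accepted forms of precision, assuming JAX is installed and imported as jnp, with jax.lax providing the Precision enum. The shapes and values are illustrative only. On CPU the setting may make no visible numerical difference; it mainly matters on accelerators.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.ones((4, 3), dtype=jnp.float32)
b = jnp.ones((3, 2), dtype=jnp.float32)

# None: the backend's default precision.
default = jnp.dot(a, b)

# A single Precision enum applies to both operands.
highest = jnp.dot(a, b, precision=lax.Precision.HIGHEST)

# A tuple of two enums sets the precision separately for a and for b.
mixed = jnp.dot(a, b, precision=(lax.Precision.HIGHEST, lax.Precision.DEFAULT))

print(default.shape, highest.shape, mixed.shape)  # (4, 2) (4, 2) (4, 2)
```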
Original docstring below.
If both a and b are 1-D arrays, it is the inner product of vectors (without complex conjugation).
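A short illustration of the 1-D case, assuming jnp is jax.numpy. It shows that the result is a 0-d array, and that complex inputs are not conjugated; jnp.vdot, which conjugates its first argument, is included only for contrast.

```python
import jax.numpy as jnp

# Real 1-D inputs: an ordinary inner product, 1*4 + 2*5 + 3*6.
r = jnp.dot(jnp.array([1., 2., 3.]), jnp.array([4., 5., 6.]))
print(r.shape)   # (): a 0-d array, not a Python float
print(float(r))  # 32.0

# Complex 1-D inputs: dot does not conjugate either argument.
x = jnp.array([1 + 2j, 3 - 1j])
y = jnp.array([2 - 1j, 1 + 1j])
plain = jnp.dot(x, y)   # sum(x * y)       -> 8+5j
conj = jnp.vdot(x, y)   # sum(conj(x) * y) -> 2-1j
```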
If both a and b are 2-D arrays, it is matrix multiplication, but using matmul or a @ b is preferred.
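A quick check of the 2-D case, assuming jnp is jax.numpy: dot, matmul, and the @ operator give the same matrix product for 2-D inputs. The matrices are arbitrary examples.

```python
import jax.numpy as jnp

a = jnp.array([[1., 2.], [3., 4.]])
b = jnp.array([[5., 6.], [7., 8.]])

r1 = jnp.dot(a, b)     # matrix product via dot
r2 = jnp.matmul(a, b)  # preferred spelling
r3 = a @ b             # preferred operator form
print(r3)              # [[19. 22.]
                       #  [43. 50.]]
```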
If either a or b is 0-D (scalar), it is equivalent to multiply, and using multiply(a, b) or a * b is preferred.
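A sketch of the 0-D case, assuming jnp is jax.numpy and using an explicit 0-d array s for the scalar. No axis is contracted, so every element of a is simply scaled.

```python
import jax.numpy as jnp

a = jnp.array([[1., 2.], [3., 4.]])
s = jnp.array(2.0)       # a 0-D (scalar) array

r1 = jnp.dot(a, s)       # same as element-wise multiplication
r2 = jnp.multiply(a, s)  # preferred spellings
r3 = a * s
```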
If a is an N-D array and b is a 1-D array, it is a sum product over the last axis of a and b.
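An illustration of the N-D by 1-D case, assuming jnp is jax.numpy. The shapes are hypothetical; the explicit jnp.sum expression spells out the sum product over the last axis of a.

```python
import jax.numpy as jnp

a = jnp.arange(24.0).reshape(2, 3, 4)  # N-D array, last axis has length 4
v = jnp.array([1., 0., -1., 2.])       # 1-D array of length 4

r = jnp.dot(a, v)                      # contracts a's last axis with v
print(r.shape)                         # (2, 3)

# The same sum product written out with broadcasting.
assert jnp.allclose(r, jnp.sum(a * v, axis=-1))
```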
If a is an N-D array and b is an M-D array (where M >= 2), it is a sum product over the last axis of a and the second-to-last axis of b:
dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m])
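A sketch that checks the indexing formula above for a 3-D a and a 3-D b, assuming jnp is jax.numpy. The einsum string is one way to restate the formula and is not how dot is implemented internally. The result shape is a.shape[:-1] followed by b.shape[:-2] and b.shape[-1].

```python
import jax.numpy as jnp

a = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)  # indices i, j, :
b = jnp.arange(5 * 4 * 6, dtype=jnp.float32).reshape(5, 4, 6)  # indices k, :, m

r = jnp.dot(a, b)
print(r.shape)  # (2, 3, 5, 6)

# dot(a, b)[i,j,k,m] = sum(a[i,j,:] * b[k,:,m]), restated with einsum.
expected = jnp.einsum("ijn,knm->ijkm", a, b)
assert jnp.allclose(r, expected)
```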
In NumPy, it uses an optimized BLAS library when possible (see numpy.linalg).
- Parameters:
a (array_like) – First argument.
b (array_like) – Second argument.
- Returns:
output – The dot product of a and b. If a and b are both scalars or both 1-D arrays, a scalar (0-d array) is returned; otherwise an array is returned.
- Return type: