jax.numpy.linalg.multi_dot
- jax.numpy.linalg.multi_dot(arrays, *, precision=None)
LAX-backend implementation of numpy.linalg.multi_dot(). Original docstring below.
Compute the dot product of two or more arrays in a single function call, while automatically selecting the fastest evaluation order.
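A minimal usage sketch, assuming JAX is installed. The shapes are arbitrary and chosen only so that they chain together; the matrices are filled with ones so the result is easy to check by hand (each entry is 100 · 5 = 500).

```python
import jax.numpy as jnp

# Shapes chain together: (10, 100) @ (100, 5) @ (5, 50) -> (10, 50)
A = jnp.ones((10, 100))
B = jnp.ones((100, 5))
C = jnp.ones((5, 50))

# One call computes the whole chain; multi_dot chooses the evaluation order.
result = jnp.linalg.multi_dot([A, B, C])
print(result.shape)  # (10, 50)
```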
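multi_dot chains numpy.dot and uses optimal parenthesization of the matrices [1] [2]. Because matrix multiplication is associative, every parenthesization gives the same result (up to floating-point rounding), but the number of scalar operations can differ greatly. Depending on the shapes of the matrices, choosing the best order can therefore speed up the multiplication considerably.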
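To make the cost difference concrete: for shapes (10, 100), (100, 5) and (5, 50), computing (AB)C takes 10·100·5 + 10·5·50 = 7,500 scalar multiplications, while A(BC) takes 100·5·50 + 10·100·50 = 75,000. The sketch below is the textbook matrix-chain dynamic program of the kind the references describe. It is plain Python for illustration only; the helper names `matrix_chain_order` and `parenthesize` are hypothetical and are not part of JAX or NumPy, and this is not claimed to be how JAX selects the order internally.

```python
def matrix_chain_order(dims):
    """Hypothetical helper: minimum scalar multiplications for A_0 @ ... @ A_{n-1},
    where matrix A_k has shape (dims[k], dims[k + 1])."""
    n = len(dims) - 1
    # cost[i][j]: cheapest way to multiply A_i .. A_j (inclusive)
    cost = [[0] * n for _ in range(n)]
    # split[i][j]: index k of the best split (A_i..A_k) @ (A_{k+1}..A_j)
    split = [[0] * n for _ in range(n)]
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            cost[i][j] = float("inf")
            for k in range(i, j):
                # Left part is (dims[i], dims[k+1]); right part is (dims[k+1], dims[j+1]).
                c = cost[i][k] + cost[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1]
                if c < cost[i][j]:
                    cost[i][j] = c
                    split[i][j] = k
    return cost[0][n - 1], split


def parenthesize(split, i, j):
    """Hypothetical helper: render the optimal order for A_i .. A_j as a string."""
    if i == j:
        return f"A{i}"
    k = split[i][j]
    return f"({parenthesize(split, i, k)} @ {parenthesize(split, k + 1, j)})"


best_cost, split = matrix_chain_order([10, 100, 5, 50])
print(best_cost, parenthesize(split, 0, 2))  # 7500 ((A0 @ A1) @ A2)
```

The dynamic program runs in O(n³) time for n matrices, which is negligible next to the matrix products themselves for realistic chain lengths.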
If the first argument is 1-D it is treated as a row vector. If the last argument is 1-D it is treated as a column vector. The other arguments must be 2-D.
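A small sketch of the 1-D rules, with arbitrary shapes and values: a 1-D first argument behaves like a row vector and a 1-D last argument like a column vector, and those vector axes do not appear in the result, as with numpy.dot on 1-D inputs.

```python
import jax.numpy as jnp

v = jnp.arange(3.0)    # shape (3,): 1-D first argument, treated as a row vector
M = jnp.ones((3, 4))   # shape (3, 4): a 2-D matrix
w = jnp.arange(4.0)    # shape (4,): 1-D last argument, treated as a column vector

row_times_matrix = jnp.linalg.multi_dot([v, M])  # result has shape (4,)
scalar = jnp.linalg.multi_dot([v, M, w])         # both ends 1-D: result has shape ()
print(row_times_matrix, scalar)
```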
Think of multi_dot as (with functools imported and NumPy imported as np):
def multi_dot(arrays):
    return functools.reduce(np.dot, arrays)
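The reduce formulation describes what is computed, not how: functools.reduce always multiplies left to right, whereas multi_dot may pick a cheaper order. The sketch below compares the two in JAX, assuming random float32 inputs with arbitrarily chosen shapes for which A @ (B @ C) needs far fewer multiplications than (A @ B) @ C (384 versus 3,072). Because a different order can round differently, the comparison uses a tolerance rather than exact equality.

```python
import functools

import jax
import jax.numpy as jnp

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
A = jax.random.normal(k1, (16, 4))
B = jax.random.normal(k2, (4, 32))
C = jax.random.normal(k3, (32, 2))

# Mental model from the text: always evaluates (A @ B) @ C, left to right.
naive = functools.reduce(jnp.dot, [A, B, C])
# multi_dot is free to evaluate A @ (B @ C) instead, which is cheaper for these shapes.
chained = jnp.linalg.multi_dot([A, B, C])

print(naive.shape, chained.shape)
print(jnp.allclose(naive, chained, rtol=1e-4, atol=1e-4))
```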
- Parameters:
arrays (sequence of array_like) – If the first argument is 1-D it is treated as a row vector. If the last argument is 1-D it is treated as a column vector. The other arguments must be 2-D.
- Returns:
output – The dot product of the supplied arrays.
- Return type:
ndarray
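The signature also accepts a keyword-only precision argument that the parameter list above does not describe. The sketch below assumes it takes the same values as the precision argument of other jax.numpy matrix products such as jnp.dot, for example jax.lax.Precision.HIGHEST, and also shows multi_dot inside jax.jit. The function name `chain_highest`, the shapes, and the values are illustrative only.

```python
import jax
import jax.numpy as jnp


@jax.jit
def chain_highest(a, b, c):
    # Assumption: precision is forwarded to the underlying matrix products;
    # HIGHEST asks backends that may default to reduced-precision passes
    # (for example TPUs) for full float32 accuracy.
    return jnp.linalg.multi_dot([a, b, c], precision=jax.lax.Precision.HIGHEST)


a = jnp.full((2, 3), 0.5)
b = jnp.full((3, 4), 2.0)
c = jnp.ones((4, 1))
out = chain_highest(a, b, c)  # every entry is 4 * (3 * 0.5 * 2.0) = 12.0
print(out.shape)  # (2, 1)
```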
References