jax.numpy.linalg.multi_dot

jax.numpy.linalg.multi_dot(arrays, *, precision=None)

Compute the dot product of two or more arrays in a single function call, while automatically selecting the fastest evaluation order.

LAX-backend implementation of numpy.linalg.multi_dot().

Original docstring below.

multi_dot chains numpy.dot and uses optimal parenthesization of the matrices [1], [2]. Depending on the shapes of the matrices, this can greatly reduce the cost of the multiplication.
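As an illustration of what "optimal parenthesization" means, here is a minimal sketch of the classic matrix-chain dynamic program described in the Cormen reference, written in plain Python. It only counts scalar multiplications for a list of matrix dimensions. The helper name `matrix_chain_order` is hypothetical, and this is not necessarily how JAX chooses the order internally.

```python
import math


def matrix_chain_order(dims):
    """Classic matrix-chain dynamic program (after Cormen, Sec. 15.2).

    Matrix i has shape (dims[i], dims[i + 1]), so len(dims) == n + 1.
    Returns (minimum scalar multiplications, optimal parenthesization).
    """
    n = len(dims) - 1
    # cost[i][j]: fewest scalar multiplications needed to form A_i ... A_j
    cost = [[0] * n for _ in range(n)]
    # split[i][j]: index k of the best split (A_i..A_k)(A_{k+1}..A_j)
    split = [[0] * n for _ in range(n)]
    for length in range(2, n + 1):
        for i in range(n - length + 1):
            j = i + length - 1
            cost[i][j] = math.inf
            for k in range(i, j):
                c = (cost[i][k] + cost[k + 1][j]
                     + dims[i] * dims[k + 1] * dims[j + 1])
                if c < cost[i][j]:
                    cost[i][j], split[i][j] = c, k

    def paren(i, j):
        # Rebuild the parenthesization from the recorded split points.
        if i == j:
            return f"A{i}"
        k = split[i][j]
        return f"({paren(i, k)} {paren(k + 1, j)})"

    return cost[0][n - 1], paren(0, n - 1)


# Shapes (10, 100), (100, 5), (5, 50)
print(matrix_chain_order([10, 100, 5, 50]))  # (7500, '((A0 A1) A2)')
```

For these shapes, computing (A0 A1) first needs 7,500 scalar multiplications, while A0 (A1 A2) needs 75,000, a tenfold difference for the same result.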

If the first argument is 1-D it is treated as a row vector. If the last argument is 1-D it is treated as a column vector. The other arguments must be 2-D.
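A short sketch of the 1-D rules, assuming a JAX install where `jnp.linalg.multi_dot` is available; the all-ones arrays and their shapes are arbitrary choices for illustration. When an end argument is 1-D, the corresponding axis is dropped from the result, so passing 1-D vectors at both ends yields a scalar.

```python
import jax.numpy as jnp

A = jnp.ones((3, 4))
B = jnp.ones((4, 5))
v = jnp.ones(3)  # 1-D first argument: treated as a row vector of shape (1, 3)
w = jnp.ones(5)  # 1-D last argument: treated as a column vector of shape (5, 1)

print(jnp.linalg.multi_dot([A, B]).shape)        # (3, 5)
print(jnp.linalg.multi_dot([v, A, B]).shape)     # (5,)
print(jnp.linalg.multi_dot([A, B, w]).shape)     # (3,)
print(jnp.linalg.multi_dot([v, A, B, w]).shape)  # ()
```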

Think of multi_dot as:

    def multi_dot(arrays): return functools.reduce(np.dot, arrays)

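To check this mental model, the sketch below compares `multi_dot` with a strict left-to-right `functools.reduce` over `jnp.dot`. The shapes are chosen so that left-to-right is not the cheapest order, so the two calls may evaluate the product differently; the results should still agree up to floating-point rounding. It also passes the keyword-only `precision` argument from the signature. Using `jax.lax.Precision.HIGHEST`, the PRNG seed, and the `1e-3` tolerances are all assumptions for this example; the higher precision is meant to keep accelerator backends from using reduced-precision matrix multiplies that could make the comparison noisy.

```python
import functools

import jax
import jax.numpy as jnp

HIGHEST = jax.lax.Precision.HIGHEST

# Random matrices whose shapes make the evaluation order matter for cost.
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
A = jax.random.normal(k1, (50, 5))
B = jax.random.normal(k2, (5, 100))
C = jax.random.normal(k3, (100, 10))

# multi_dot chooses the evaluation order itself.
fast = jnp.linalg.multi_dot([A, B, C], precision=HIGHEST)
# The mental model: a strict left-to-right reduction, ((A @ B) @ C).
naive = functools.reduce(functools.partial(jnp.dot, precision=HIGHEST), [A, B, C])

print(fast.shape)  # (50, 10)
print(bool(jnp.allclose(fast, naive, rtol=1e-3, atol=1e-3)))
```
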
Parameters

arrays (sequence of array_like) – If the first argument is 1-D it is treated as a row vector. If the last argument is 1-D it is treated as a column vector. The other arguments must be 2-D.

Returns

output – Returns the dot product of the supplied arrays.

Return type

ndarray

References

[1] Cormen, “Introduction to Algorithms”, Chapter 15.2, pp. 370-378.

[2] https://en.wikipedia.org/wiki/Matrix_chain_multiplication