jax.numpy.linalg.multi_dot

jax.numpy.linalg.multi_dot(arrays, *, precision=None)

Compute the dot product of two or more arrays in a single function call, while automatically selecting the fastest evaluation order.

LAX-backend implementation of numpy.linalg.multi_dot().

Original docstring below.

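multi_dot chains numpy.dot and uses the optimal parenthesization of the matrices [1] [2]. Depending on the shapes of the matrices, this can speed up the multiplication considerably, because the number of scalar multiplications depends on the order in which the pairwise products are evaluated.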
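
To make the effect of parenthesization concrete, the sketch below counts scalar multiplications for a three-matrix chain under the usual naive cost model, where an (m, k) by (k, n) product costs m * k * n multiplications. The helper matmul_cost and the shapes are illustrative assumptions, not part of the JAX or NumPy API, and the count only approximates real runtime.

```python
def matmul_cost(shape_a, shape_b):
    """Scalar multiplications for a naive (m, k) @ (k, n) product (hypothetical helper)."""
    m, k = shape_a
    k2, n = shape_b
    assert k == k2, "inner dimensions must match"
    return m * k * n

# Illustrative shapes for the chain A @ B @ C.
a, b, c = (10, 100), (100, 5), (5, 50)

# (A @ B) @ C: the intermediate result has shape (10, 5).
cost_left = matmul_cost(a, b) + matmul_cost((a[0], b[1]), c)

# A @ (B @ C): the intermediate result has shape (100, 50).
cost_right = matmul_cost(b, c) + matmul_cost(a, (b[0], c[1]))

print(cost_left, cost_right)  # 7500 75000
```

For these shapes the left-to-right grouping needs ten times fewer multiplications, which is the kind of difference an optimal evaluation order exploits.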

If the first argument is 1-D it is treated as a row vector. If the last argument is 1-D it is treated as a column vector. The other arguments must be 2-D.
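
The handling of 1-D arguments at either end of the chain can be illustrated with a short sketch. The array names and shapes are assumptions chosen for the example; only jnp.linalg.multi_dot itself comes from the documented API.

```python
import jax.numpy as jnp

v = jnp.arange(3.0)      # 1-D, shape (3,): acts as a row vector when first
A = jnp.ones((3, 4))     # 2-D
B = jnp.ones((4, 2))     # 2-D
w = jnp.arange(2.0)      # 1-D, shape (2,): acts as a column vector when last

# 1-D first argument: the result is 1-D with shape (2,).
row_result = jnp.linalg.multi_dot([v, A, B])

# 1-D last argument: the result is 1-D with shape (3,).
col_result = jnp.linalg.multi_dot([A, B, w])

# 1-D at both ends: the result is a scalar (0-D array).
scalar_result = jnp.linalg.multi_dot([v, A, B, w])

print(row_result.shape, col_result.shape, scalar_result.shape)
```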

Think of multi_dot as:

def multi_dot(arrays):
    return functools.reduce(np.dot, arrays)
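
As a rough check of this equivalence, the sketch below compares jnp.linalg.multi_dot against a plain left-to-right chain built with functools.reduce and jnp.dot. The random inputs, shapes, and tolerances are assumptions for illustration; the two results should agree up to floating-point rounding, since only the evaluation order differs.

```python
import functools

import jax
import jax.numpy as jnp

# Illustrative random matrices forming a valid chain of shapes.
key_a, key_b, key_c = jax.random.split(jax.random.PRNGKey(0), 3)
A = jax.random.normal(key_a, (10, 100))
B = jax.random.normal(key_b, (100, 5))
C = jax.random.normal(key_c, (5, 50))

# Fixed left-to-right evaluation: (A @ B) @ C.
chained = functools.reduce(jnp.dot, [A, B, C])

# Same product, with the evaluation order left to multi_dot.
optimized = jnp.linalg.multi_dot([A, B, C])

print(chained.shape, bool(jnp.allclose(chained, optimized, rtol=1e-4, atol=1e-3)))
```
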
Parameters:

arrays (sequence of array_like) – If the first argument is 1-D it is treated as a row vector. If the last argument is 1-D it is treated as a column vector. The other arguments must be 2-D.

Returns:

output – The dot product of the supplied arrays.

Return type:

ndarray

References