# jax.numpy.linalg.multi_dot

jax.numpy.linalg.multi_dot(arrays, *, precision=None)

Efficiently compute matrix products between a sequence of arrays.

JAX implementation of `numpy.linalg.multi_dot()`.

JAX internally uses the opt_einsum library to compute the most efficient operation order.
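As a rough sketch of that search, a chain of three matrices with shapes (200, 5), (5, 100) and (100, 10) can be written as the single einsum `'ij,jk,kl->il'`, and `jnp.einsum_path` reports the contraction order opt_einsum picks for it. This assumes `multi_dot` reduces to an equivalent einsum. The `optimize='optimal'` setting is chosen here for illustration and may not match what `multi_dot` uses internally.

```python
import jax
import jax.numpy as jnp

key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
x = jax.random.normal(key1, shape=(200, 5))
y = jax.random.normal(key2, shape=(5, 100))
z = jax.random.normal(key3, shape=(100, 10))

# Express x @ y @ z as one einsum and ask opt_einsum for a contraction path
# (optimize='optimal' is an illustrative choice, not necessarily multi_dot's).
path, info = jnp.einsum_path('ij,jk,kl->il', x, y, z, optimize='optimal')

# Each entry of `path` lists the operand positions contracted at that step;
# the intermediate result is appended to the end of the operand list.
print(path)  # the first step is expected to pair y and z (operands 1 and 2)
print(info)  # human-readable summary, including estimated FLOP counts
```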

Parameters:
• arrays (Sequence[jax.typing.ArrayLike]) – sequence of arrays. All must be two-dimensional, except the first and last, which may be one-dimensional.

• precision (str | Precision | tuple[str, str] | tuple[Precision, Precision] | None) – either `None` (default), which means the default precision for the backend; a `Precision` enum value (`Precision.DEFAULT`, `Precision.HIGH` or `Precision.HIGHEST`); or a tuple of two such values, for the left- and right-hand operands of each product.

Returns:

an array equivalent to `functools.reduce(jnp.matmul, arrays)` (up to floating point rounding), but evaluated in the optimal order.

Return type:

Array
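A short sketch of the parameter and return-value rules above: a one-dimensional first or last operand acts as a vector, so a chain that starts and ends with vectors reduces to a scalar, and the result matches a left-to-right `functools.reduce` over `jnp.matmul`. The shapes, random keys, and the choice of `jax.lax.Precision.HIGHEST` are illustrative only.

```python
import functools

import jax
import jax.numpy as jnp

k1, k2, k3, k4 = jax.random.split(jax.random.key(1), 4)
v = jax.random.normal(k1, shape=(4,))    # 1-D first operand
a = jax.random.normal(k2, shape=(4, 3))
b = jax.random.normal(k3, shape=(3, 5))
w = jax.random.normal(k4, shape=(5,))    # 1-D last operand

# multi_dot picks its own evaluation order; reduce evaluates strictly left to right.
out = jnp.linalg.multi_dot([v, a, b, w], precision=jax.lax.Precision.HIGHEST)
ref = functools.reduce(jnp.matmul, [v, a, b, w])

print(out.shape)                          # () because both ends are vectors
print(jnp.allclose(out, ref, atol=1e-4))  # equal up to floating point rounding
```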

This function exists because the cost of evaluating a sequence of matmul operations can vary enormously depending on the order in which they are performed. For a single product of an `(m, k)` matrix with a `(k, n)` matrix, the number of floating point operations (flops) is approximately `2 * m * k * n` (one multiply and one add for each of the `k` terms in each of the `m * n` output entries):

```
>>> def approx_flops(x, y):
...   # for 2D x and y, with x.shape[1] == y.shape[0]
...   return 2 * x.shape[0] * x.shape[1] * y.shape[1]
```

Suppose we have three matrices that we'd like to multiply in sequence:

```
>>> import jax
>>> import jax.numpy as jnp
>>> key1, key2, key3 = jax.random.split(jax.random.key(0), 3)
>>> x = jax.random.normal(key1, shape=(200, 5))
>>> y = jax.random.normal(key2, shape=(5, 100))
>>> z = jax.random.normal(key3, shape=(100, 10))
```

Because matrix multiplication is associative, there are two orders in which we might evaluate the product `x @ y @ z`, and both produce the same result up to floating point rounding:

```
>>> result1 = (x @ y) @ z
>>> result2 = x @ (y @ z)
>>> jnp.allclose(result1, result2, atol=1E-4)
Array(True, dtype=bool)
```

But their computational costs differ greatly:

```
>>> print("(x @ y) @ z flops:", approx_flops(x, y) + approx_flops(x @ y, z))
(x @ y) @ z flops: 600000
>>> print("x @ (y @ z) flops:", approx_flops(y, z) + approx_flops(x, y @ z))
x @ (y @ z) flops: 30000
```

The second approach is about 20x more efficient in terms of estimated flops!
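The per-matmul estimate extends to any parenthesization of a chain: walk the tree, add `2 * m * k * n` for each product, and carry the intermediate shapes upward. The helper `tree_flops` below is a hypothetical name introduced for illustration; it works on shapes alone and reproduces the two totals above.

```python
def tree_flops(tree, shapes):
    """Estimate (result_shape, flops) for one parenthesization of a matrix chain.

    `tree` is either an int indexing into `shapes`, or a pair (left, right)
    of subtrees. Multiplying an (m, k) by a (k, n) matrix costs 2*m*k*n flops.
    """
    if isinstance(tree, int):
        return shapes[tree], 0
    left, right = tree
    (m, k), left_flops = tree_flops(left, shapes)
    (k2, n), right_flops = tree_flops(right, shapes)
    if k != k2:
        raise ValueError(f"inner dimensions do not match: {k} vs {k2}")
    return (m, n), left_flops + right_flops + 2 * m * k * n


shapes = [(200, 5), (5, 100), (100, 10)]  # shapes of x, y, z
print(tree_flops(((0, 1), 2), shapes))   # (x @ y) @ z -> ((200, 10), 600000)
print(tree_flops((0, (1, 2)), shapes))   # x @ (y @ z) -> ((200, 10), 30000)
```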

`multi_dot` automatically chooses the cheapest evaluation order (by estimated flops) for products like this:

```
>>> result3 = jnp.linalg.multi_dot([x, y, z])
>>> jnp.allclose(result1, result3, atol=1E-4)
Array(True, dtype=bool)
```
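How might the cheapest order be found? For a plain chain of matrix products, one classic approach is the textbook matrix-chain dynamic program, sketched below with the same `2 * m * k * n` cost model. This illustrates the idea only: `multi_dot` delegates the search to opt_einsum, whose path optimizers are not necessarily this algorithm, and `optimal_chain_order` is a hypothetical name.

```python
from functools import lru_cache


def optimal_chain_order(shapes):
    """Return (min_flops, tree) for a chain of 2-D shapes using the classic
    O(n^3) matrix-chain dynamic program; `tree` nests index pairs."""
    for (_, k), (k2, _) in zip(shapes, shapes[1:]):
        if k != k2:
            raise ValueError("inner dimensions do not match")
    # Matrix i has shape (dims[i], dims[i + 1]).
    dims = [shapes[0][0]] + [s[1] for s in shapes]

    @lru_cache(maxsize=None)
    def best(i, j):
        # Cheapest way to multiply matrices i..j (inclusive).
        if i == j:
            return 0, i
        candidates = []
        for s in range(i, j):  # split point: (i..s) @ (s+1..j)
            left_cost, left_tree = best(i, s)
            right_cost, right_tree = best(s + 1, j)
            cost = left_cost + right_cost + 2 * dims[i] * dims[s + 1] * dims[j + 1]
            candidates.append((cost, (left_tree, right_tree)))
        return min(candidates, key=lambda c: c[0])

    return best(0, len(shapes) - 1)


print(optimal_chain_order([(200, 5), (5, 100), (100, 10)]))
# -> (30000, (0, (1, 2))), i.e. x @ (y @ z)
```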

We can use JAX's ahead-of-time (AOT) lowering API to estimate the total flops of each approach, and confirm that `multi_dot` chooses the more efficient option:

```
>>> jax.jit(lambda x, y, z: (x @ y) @ z).lower(x, y, z).cost_analysis()['flops']
600000.0
>>> jax.jit(lambda x, y, z: x @ (y @ z)).lower(x, y, z).cost_analysis()['flops']
30000.0
>>> jax.jit(jnp.linalg.multi_dot).lower([x, y, z]).cost_analysis()['flops']
30000.0
```
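The same check can be applied to other chains. Below is a sketch using a hypothetical helper, `estimated_flops`, that wraps the AOT calls above. It assumes `cost_analysis()` returns a dict with a `'flops'` key, as in the example. It also accepts a list of such dicts, on the assumption that some JAX versions return that form, which is worth checking against your version. For two square matrices followed by a vector, evaluating right to left needs only matrix-vector products, so `multi_dot` should report far fewer flops than the left-to-right chain.

```python
import jax
import jax.numpy as jnp


def estimated_flops(fn, *args):
    """Hypothetical helper: XLA's flop estimate for `fn` lowered on `args`."""
    analysis = jax.jit(fn).lower(*args).cost_analysis()
    if isinstance(analysis, (list, tuple)):  # assumption: list-of-dicts form
        analysis = analysis[0]
    return analysis['flops']


ka, kb, kv = jax.random.split(jax.random.key(2), 3)
a = jax.random.normal(ka, shape=(500, 500))
b = jax.random.normal(kb, shape=(500, 500))
v = jax.random.normal(kv, shape=(500,))

# Left to right forms a 500x500 @ 500x500 product before touching v.
left_to_right = estimated_flops(lambda a, b, v: (a @ b) @ v, a, b, v)
chosen = estimated_flops(jnp.linalg.multi_dot, [a, b, v])
print(left_to_right, chosen)  # multi_dot is expected to avoid the matrix-matrix product
```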