jax.numpy.einsum

jax.numpy.einsum(*operands, out=None, optimize='optimal', precision=None, _use_xeinsum=False)

Evaluates the Einstein summation convention on the operands.

LAX-backend implementation of numpy.einsum().

In addition to the original NumPy arguments listed below, einsum also supports precision, which gives extra control over matrix-multiplication precision on supported devices. precision may be None (the default precision for the backend), a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two lax.Precision enums giving a separate precision for each argument.
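The three forms of precision described above can be sketched as follows. This is a minimal illustration, assuming small example arrays a and b that are not part of the original documentation; whether the setting has any visible effect depends on the backend, and on CPU the results will usually be identical.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical example operands (not from the original docs).
a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(3, 4)

# precision=None: use the backend's default precision.
default = jnp.einsum("ij,jk->ik", a, b)

# A single lax.Precision value applies to both operands.
highest = jnp.einsum("ij,jk->ik", a, b, precision=lax.Precision.HIGHEST)

# A tuple of two lax.Precision values sets the precision per operand.
mixed = jnp.einsum(
    "ij,jk->ik", a, b,
    precision=(lax.Precision.HIGH, lax.Precision.DEFAULT),
)
```

The precision setting is mainly relevant on accelerators where matrix multiplication may run at reduced precision by default; the output shape and semantics are the same in all three calls.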

Original docstring below.

Using the Einstein summation convention, many common multi-dimensional, linear algebraic array operations can be represented in a simple fashion. In implicit mode, einsum computes these values.

In explicit mode, einsum provides further flexibility to compute other array operations that might not be considered classical Einstein summation operations, by disabling or forcing summation over specified subscript labels.

See the notes and examples for clarification.
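As a short illustration of the two modes, the sketch below contrasts an implicit-mode call (no "->" in the subscripts string) with explicit-mode calls that force or disable summation. The arrays a, b and m are hypothetical example inputs chosen only for this illustration.

```python
import jax.numpy as jnp

# Hypothetical example operands (not from the original docs).
a = jnp.arange(6.0).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]
b = jnp.arange(12.0).reshape(3, 4)
m = jnp.arange(9.0).reshape(3, 3)   # [[0, 1, 2], [3, 4, 5], [6, 7, 8]]

# Implicit mode: the repeated label j is summed, and the remaining
# labels (i, k) form the output in alphabetical order.
implicit = jnp.einsum("ij,jk", a, b)        # matrix product, shape (2, 4)

# Explicit mode: "->" fixes the output labels and their order.
explicit_t = jnp.einsum("ij,jk->ki", a, b)  # transposed product, shape (4, 2)

# Explicit mode can force summation over a label that is not repeated.
row_sums = jnp.einsum("ij->i", a)           # [3, 12]

# Explicit mode can also disable summation over a repeated label.
diag = jnp.einsum("ii->i", m)               # diagonal: [0, 4, 8]
trace = jnp.einsum("ii", m)                 # implicit mode sums it: 12
```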

Parameters
  • operands (list of array_like) – These are the arrays for the operation.

  • optimize ({False, True, 'greedy', 'optimal'}, optional) – Controls whether intermediate optimization should occur. No optimization will occur if False, and True will default to the 'greedy' algorithm. Also accepts an explicit contraction list from the np.einsum_path function. See np.einsum_path for more details. NumPy's default is False; in jax.numpy.einsum the default is 'optimal', as shown in the signature above.
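The effect of optimize is easiest to see with three or more operands, where the contraction order is a real choice. The following is a minimal sketch, assuming hypothetical all-ones operands a, b and c; it only shows that different optimize settings give the same result, and does not measure any performance difference.

```python
import jax.numpy as jnp

# Hypothetical example operands (not from the original docs).
a = jnp.ones((8, 16))
b = jnp.ones((16, 32))
c = jnp.ones((32, 4))

# Default in jax.numpy.einsum: optimize='optimal'.
default = jnp.einsum("ij,jk,kl->il", a, b, c)

# Explicitly select the 'greedy' contraction-order algorithm.
greedy = jnp.einsum("ij,jk,kl->il", a, b, c, optimize="greedy")

# Both contraction orders compute the same chained matrix product.
same = bool(jnp.allclose(default, greedy))
```

The choice of algorithm changes only the order in which pairwise contractions are performed, not the mathematical result, so the setting matters mostly for expressions with many operands.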

Returns

output – The calculation based on the Einstein summation convention.

Return type

ndarray