jax.numpy.einsum


jax.numpy.einsum(subscript: str, /, *operands: Array | ndarray | bool_ | number | bool | int | float | complex, out: None = None, optimize: str = 'optimal', precision: str | Precision | tuple[str, str] | tuple[Precision, Precision] | None = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general) -> Array
jax.numpy.einsum(arr: Array | ndarray | bool_ | number | bool | int | float | complex, axes: Sequence[Any], /, *operands: ArrayLike | Sequence[Any], out: None = None, optimize: str = 'optimal', precision: str | Precision | tuple[str, str] | tuple[Precision, Precision] | None = None, preferred_element_type: DTypeLike | None = None, _dot_general: Callable[..., Array] = lax.dot_general) -> Array
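The two signatures correspond to the two calling conventions NumPy also accepts: a subscript string followed by the operands, or the "sublist" form in which each operand is followed by a list of integer axis labels and an optional final list names the output axes. The sketch below assumes a recent JAX install; the array names and shapes are illustrative only.

```python
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)   # shape (2, 3)
b = jnp.arange(12.0).reshape(3, 4)  # shape (3, 4)

# First signature: a subscript string followed by the operands.
c1 = jnp.einsum("ij,jk->ik", a, b)

# Second signature: each operand is followed by a list of integer axis labels,
# and the trailing list gives the output labels (0 plays the role of i,
# 1 of j, 2 of k), so this is the same matrix product as above.
c2 = jnp.einsum(a, [0, 1], b, [1, 2], [0, 2])

print(c1.shape)  # (2, 4)
```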

Evaluates the Einstein summation convention on the operands.

LAX-backend implementation of numpy.einsum().

In addition to the original NumPy arguments listed below, jax.numpy.einsum() also supports precision, for extra control over matrix-multiplication precision on supported devices. precision may be None (the default precision for the backend), a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two Precision enums giving a separate precision for each argument of a matrix multiplication. A tuple precision does not necessarily correspond to the operands of einsum(); rather, the specified precision is forwarded to each dot_general call used in the implementation. The preferred_element_type argument, also in the signature above, is either None (the default) or a dtype, indicating that results should be accumulated in and returned with that dtype.
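A minimal sketch of passing precision, assuming a recent JAX install. Whether HIGHEST changes the numerics depends on the backend (on CPU it typically makes no visible difference), so the example only shows the accepted forms of the argument.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((4, 8), dtype=jnp.float32)
y = jnp.ones((8, 3), dtype=jnp.float32)

# A single Precision value is applied to every dot_general call einsum makes.
z_highest = jnp.einsum("ij,jk->ik", x, y, precision=lax.Precision.HIGHEST)

# A pair sets the precision for the two inputs of each dot_general call,
# not for the first and second einsum operands specifically.
z_pair = jnp.einsum(
    "ij,jk->ik", x, y,
    precision=(lax.Precision.HIGHEST, lax.Precision.DEFAULT),
)
```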

jax.numpy.einsum() also differs from numpy.einsum() in that the optimize keyword defaults to "optimal" rather than False.
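The optimize setting only chooses the order of the pairwise contractions; it should not change the result beyond floating-point rounding. The sketch below, assuming a recent JAX install and arbitrary illustrative shapes, compares the default "optimal" strategy with "greedy" on a three-operand chain.

```python
import jax.numpy as jnp

a = jnp.arange(12.0).reshape(3, 4)
b = jnp.arange(20.0).reshape(4, 5)
c = jnp.arange(10.0).reshape(5, 2)

# Default: optimize="optimal" searches for the cheapest contraction order.
r_optimal = jnp.einsum("ij,jk,kl->il", a, b, c)

# "greedy" picks an order heuristically; only the order in which the
# intermediate products are formed differs, not the mathematical result.
r_greedy = jnp.einsum("ij,jk,kl->il", a, b, c, optimize="greedy")
```

For large operand counts, "greedy" is usually cheaper to plan than an exhaustive "optimal" search, at the risk of a less efficient contraction order.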

Original docstring below.

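Using the Einstein summation convention, many common multi-dimensional linear-algebra array operations can be expressed concisely. In implicit mode (no '->' in the subscripts), einsum computes these values: labels that appear more than once are summed over, and the output axes are the labels that appear exactly once, in alphabetical order.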
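A few implicit-mode calls, as a sketch assuming a recent JAX install. Because the output labels are sorted alphabetically, "ji" yields the transpose.

```python
import jax.numpy as jnp

m = jnp.arange(9.0).reshape(3, 3)
v = jnp.arange(3.0)

# No "->": repeated labels are summed, remaining labels are sorted.
trace = jnp.einsum("ii", m)        # sum of the diagonal, a scalar
matvec = jnp.einsum("ij,j", m, v)  # j is repeated and summed; output label "i"
transpose = jnp.einsum("ji", m)    # output labels become "ij", i.e. m.T
```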

In explicit mode (subscripts containing '->' followed by the output labels), einsum provides further flexibility to compute array operations that might not be considered classical Einstein summation, by disabling or forcing summation over specified subscript labels.
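In explicit mode the output labels decide what is summed. The sketch below, assuming a recent JAX install, keeps a repeated label to extract a diagonal, sums over only one axis, forms an outer product, and disables summation entirely to get an elementwise product.

```python
import jax.numpy as jnp

m = jnp.arange(9.0).reshape(3, 3)
u = jnp.array([1.0, 2.0])
w = jnp.array([3.0, 4.0, 5.0])

diag = jnp.einsum("ii->i", m)             # keep the repeated label: the diagonal
col_sums = jnp.einsum("ij->j", m)         # sum over i only
outer = jnp.einsum("i,j->ij", u, w)       # no shared labels: outer product
hadamard = jnp.einsum("ij,ij->ij", m, m)  # both labels kept: no summation
```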

See the Notes and Examples sections of the numpy.einsum documentation for clarification.

Parameters:
  • subscripts (str) – Specifies the subscripts for summation as a comma-separated list of subscript labels. An implicit (classical Einstein summation) calculation is performed unless the explicit indicator ‘->’ is included, followed by the subscript labels of the exact output form.

  • operands (list of array_like) – These are the arrays for the operation.

  • optimize ({False, True, 'greedy', 'optimal'}, optional) – Controls whether intermediate optimization should occur. No optimization occurs if False, and True defaults to the ‘greedy’ algorithm. Also accepts an explicit contraction list from the np.einsum_path function. See np.einsum_path for more details. Defaults to False in NumPy; jax.numpy.einsum() defaults to 'optimal'.
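To see which contraction order a strategy picks, jax.numpy.einsum_path (the counterpart of np.einsum_path) plans the contraction without performing it. This is a sketch assuming a recent JAX install; the shapes are arbitrary and chosen so that contracting the first two operands first is clearly cheaper.

```python
import jax.numpy as jnp

a = jnp.ones((10, 30))
b = jnp.ones((30, 5))
c = jnp.ones((5, 60))

# einsum_path only plans: it returns (path, info), where path lists the
# operand indices contracted at each pairwise step.
path, info = jnp.einsum_path("ij,jk,kl->il", a, b, c, optimize="greedy")
print(path)
print(info)  # human-readable report of the planned contraction

# The actual computation; every entry is 30 * 5 = 150 for all-ones inputs.
result = jnp.einsum("ij,jk,kl->il", a, b, c, optimize="greedy")
```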

Returns:

output – The calculation based on the Einstein summation convention.

Return type:

Array