jax.numpy.einsum_path



jax.numpy.einsum_path(subscripts, *operands, optimize='greedy')

LAX-backend implementation of numpy.einsum_path().

Original docstring below.

Evaluates the lowest cost contraction order for an einsum expression by considering the creation of intermediate arrays.
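A minimal usage sketch follows. It assumes a JAX version whose jax.numpy.einsum_path returns the (path, info) pair described under Returns. The subscripts and shapes are illustrative only: three operands form a matrix chain, so the returned path should contain one pairwise contraction per step, two steps in total.

```python
import jax.numpy as jnp

# Illustrative matrix chain: (8x4) @ (4x32) @ (32x2).
a = jnp.ones((8, 4))
b = jnp.ones((4, 32))
c = jnp.ones((32, 2))

path, info = jnp.einsum_path("ij,jk,kl->il", a, b, c, optimize="greedy")
print(path)  # list of index tuples, one per pairwise contraction
print(info)  # printable summary of the chosen order and its estimated cost
```

Only the shapes of the operands matter for choosing the path, so placeholder arrays such as jnp.ones are enough to plan a contraction before the real data exists.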

Parameters:
  • subscripts (str) – Specifies the subscripts for summation.

  • *operands (list of array_like) – These are the arrays for the operation.

  • optimize ({bool, list, tuple, 'greedy', 'optimal'}) –

    Choose the type of path. If a tuple is provided, its second element is taken as the maximum size of any intermediate array created. If only a single argument is provided, the largest input or output array size is used as the maximum intermediate size.

    • If a list starting with 'einsum_path' is given, it is used as the contraction path.

    • If False, no optimization is performed.

    • If True, defaults to the ‘greedy’ algorithm.

    • ‘optimal’: An algorithm that combinatorially explores all possible ways of contracting the listed tensors and chooses the least costly path. Scales exponentially with the number of terms in the contraction.

    • ‘greedy’: An algorithm that chooses the best pair contraction at each step. Effectively, it searches for the largest inner, Hadamard, and then outer products at each step. Scales cubically with the number of terms in the contraction. Equivalent to the ‘optimal’ path for most contractions.

    Default is ‘greedy’.
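The sketch below compares the two string strategies on an illustrative four-operand chain. It assumes, as in recent JAX releases, that jax.numpy.einsum accepts a precomputed path through its own optimize argument. The tuple and True/False forms are not exercised here, since their handling may differ between JAX versions. The two strategies may well return the same path for a contraction this small; the point is that the path changes cost, not the result.

```python
import jax
import jax.numpy as jnp

# Illustrative operands for a four-term matrix chain.
key_a, key_b, key_c, key_d = jax.random.split(jax.random.PRNGKey(0), 4)
a = jax.random.normal(key_a, (6, 3))
b = jax.random.normal(key_b, (3, 10))
c = jax.random.normal(key_c, (10, 5))
d = jax.random.normal(key_d, (5, 2))
subscripts = "ij,jk,kl,lm->im"

results = {}
for strategy in ("greedy", "optimal"):
    path, info = jnp.einsum_path(subscripts, a, b, c, d, optimize=strategy)
    print(strategy, path)
    # Reuse the precomputed path so einsum does not search for one again.
    results[strategy] = jnp.einsum(subscripts, a, b, c, d, optimize=path)

# Different contraction orders may differ slightly in float32 rounding,
# so the comparison uses a loose tolerance.
print(jnp.allclose(results["greedy"], results["optimal"], rtol=1e-4, atol=1e-4))
```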

Returns:

  • path (list of tuples) – A list representation of the einsum path.

  • string_repr (str) – A printable representation of the einsum path.
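Each tuple in path indexes into the current list of operands. Under the convention used by NumPy and opt_einsum, the operands at the listed positions are removed, and the result of contracting them is appended to the end of the list. The helper replay_path below is hypothetical and is not a JAX API. It only replays that bookkeeping on operand names, which makes a returned path easier to read.

```python
import jax.numpy as jnp

def replay_path(path, names):
    """Hypothetical helper (not part of JAX): show which operands each step combines.

    Assumes the NumPy/opt_einsum convention: the operands at the listed positions
    are removed from the current list, and their result is appended to the end.
    """
    current = list(names)
    steps = []
    for contract_inds in path:
        # Pop from the highest position first so lower positions stay valid.
        picked = [current.pop(i) for i in sorted(contract_inds, reverse=True)]
        combined = "(" + " * ".join(reversed(picked)) + ")"
        current.append(combined)
        steps.append(combined)
    return steps

# Hand-written path: contract B with C first, then A with that intermediate.
print(replay_path([(1, 2), (0, 1)], ["A", "B", "C"]))
# → ['(B * C)', '(A * (B * C))']

# The same bookkeeping applied to a path computed by einsum_path.
path, info = jnp.einsum_path(
    "ij,jk,kl->il",
    jnp.ones((8, 4)), jnp.ones((4, 32)), jnp.ones((32, 2)),
    optimize="greedy",
)
print(replay_path(path, ["A", "B", "C"]))
print(info)
```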