jax.numpy.einsum_path

jax.numpy.einsum_path(subscripts, *operands, optimize='greedy')[source]

Evaluates the lowest-cost contraction order for an einsum expression by considering the creation of intermediate arrays.

LAX-backend implementation of einsum_path(). Original docstring below.

einsum_path(subscripts, *operands, optimize='greedy')

Parameters
  • subscripts (str) – Specifies the subscripts for summation.

  • *operands (list of array_like) – These are the arrays for the operation.

  • optimize ({bool, list, tuple, 'greedy', 'optimal'}) – Choose the type of path. If a tuple is provided, its second element is taken as the maximum intermediate size created. If only a single argument is provided, the largest input or output array size is used as the maximum intermediate size.
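The forms of optimize described above can be compared side by side. The following is a minimal sketch that uses NumPy's np.einsum_path, as the examples in this docstring do; the array shapes and the limit of 100 elements are arbitrary values chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((2, 2))
b = rng.random((2, 5))
c = rng.random((5, 2))

# Named strategies: pick the contraction order with each algorithm.
greedy_path, _ = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
optimal_path, _ = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='optimal')

# Tuple form: (strategy, maximum intermediate size in elements).
# The limit of 100 is an arbitrary illustrative value.
limited_path, _ = np.einsum_path('ij,jk,kl->il', a, b, c,
                                 optimize=('greedy', 100))

print(greedy_path)
print(optimal_path)
print(limited_path)
```

Each returned path is a list whose first entry is the string 'einsum_path', followed by one tuple of operand positions per contraction step.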

Returns

  • path (list of tuples) – A list representation of the einsum path.

  • string_repr (str) – A printable representation of the einsum path.
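As a minimal sketch of how the two return values might be used, the path can be passed back to np.einsum through its optimize argument so that the contraction order does not have to be searched again. The example uses NumPy, as the rest of this docstring does, and the shapes are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((2, 2))
b = rng.random((2, 5))
c = rng.random((5, 2))

# Unpack the two return values.
path, string_repr = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')

# The path is a plain list; string_repr is meant for printing.
print(path)
print(string_repr)

# Reuse the precomputed path when evaluating the expression.
result = np.einsum('ij,jk,kl->il', a, b, c, optimize=path)

# Compare against an ordinary chained matrix product.
print(np.allclose(result, a @ b @ c))
```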

Notes

The resulting path indicates which terms of the input contraction should be contracted first. The result of each contraction is then appended to the end of the contraction list. This list can then be iterated over until all intermediate contractions are complete.
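The iteration described in this note can be sketched by hand. The helper below, contract_along_path, is a hypothetical name introduced only for illustration and is not part of NumPy or JAX. It assumes the subscripts use the explicit 'inputs->output' form with no ellipsis, and it is not how the library evaluates a path internally.

```python
import numpy as np

def contract_along_path(subscripts, operands, path):
    """Hypothetical helper: apply an einsum path one step at a time."""
    inputs, output = subscripts.split('->')
    terms = inputs.split(',')
    arrays = list(operands)
    for positions in path[1:]:  # skip the leading 'einsum_path' marker
        # Pop from the highest position first so lower positions stay valid.
        picked = [(terms.pop(p), arrays.pop(p))
                  for p in sorted(positions, reverse=True)]
        picked_terms = [term for term, _ in picked]
        picked_arrays = [array for _, array in picked]
        # Keep an index only if a remaining term or the output still needs it.
        needed = set(output).union(*terms)
        kept = ''.join(sorted(set(''.join(picked_terms)) & needed))
        step = ','.join(picked_terms) + '->' + kept
        # Append the intermediate to the end of the operand list.
        arrays.append(np.einsum(step, *picked_arrays))
        terms.append(kept)
    # One term remains; reorder its axes to match the requested output.
    return np.einsum(terms[0] + '->' + output, arrays[0])

rng = np.random.default_rng(0)
a = rng.random((2, 2))
b = rng.random((2, 5))
c = rng.random((5, 2))

path, _ = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
stepwise = contract_along_path('ij,jk,kl->il', (a, b, c), path)
direct = np.einsum('ij,jk,kl->il', a, b, c)
print(np.allclose(stepwise, direct))
```

Popping positions in descending order matters: removing the lower position first would shift the later operands and select the wrong array.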

Examples

We can begin with a chain dot example. In this case, it is optimal to contract the b and c tensors first, as represented by the first element of the path, (1, 2). The resulting tensor is appended to the end of the operand list, and the remaining contraction, (0, 1), is then completed.

>>> np.random.seed(123)
>>> a = np.random.rand(2, 2)
>>> b = np.random.rand(2, 5)
>>> c = np.random.rand(5, 2)
>>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
>>> print(path_info[0])
['einsum_path', (1, 2), (0, 1)]
>>> print(path_info[1])
  Complete contraction:  ij,jk,kl->il # may vary
         Naive scaling:  4
     Optimized scaling:  3
      Naive FLOP count:  1.600e+02
  Optimized FLOP count:  5.600e+01
   Theoretical speedup:  2.857
  Largest intermediate:  4.000e+00 elements
-------------------------------------------------------------------------
scaling                  current                                remaining
-------------------------------------------------------------------------
   3                   kl,jk->jl                                ij,jl->il
   3                   jl,ij->il                                   il->il
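The preference for contracting b and c first can be checked by hand. The sketch below assumes a simple cost model of one multiply and one add per point of the index space of each pairwise contraction. That model is an assumption made for illustration; it reproduces the optimized count of 56 shown above, but the naive count reported by the library depends on its version.

```python
from math import prod

# Index lengths taken from the shapes of a (2, 2), b (2, 5) and c (5, 2).
sizes = {'i': 2, 'j': 2, 'k': 5, 'l': 2}

def pair_cost(indices):
    # Assumed cost model: one multiply and one add per index-space point.
    return 2 * prod(sizes[ix] for ix in indices)

# Contract b and c first (kl,jk->jl), then combine with a (jl,ij->il).
bc_first = pair_cost('jkl') + pair_cost('ijl')

# Contract a and b first (ij,jk->ik), then combine with c (ik,kl->il).
ab_first = pair_cost('ijk') + pair_cost('ikl')

print(bc_first, ab_first)  # 56 80
```

Contracting b and c first removes the longest index, k, in the first step, so the second step works only on indices of length 2.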

A more complex index transformation example.

>>> I = np.random.rand(10, 10, 10, 10)
>>> C = np.random.rand(10, 10)
>>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
...                            optimize='greedy')
>>> print(path_info[0])
['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
>>> print(path_info[1])
  Complete contraction:  ea,fb,abcd,gc,hd->efgh # may vary
         Naive scaling:  8
     Optimized scaling:  5
      Naive FLOP count:  8.000e+08
  Optimized FLOP count:  8.000e+05
   Theoretical speedup:  1000.000
  Largest intermediate:  1.000e+04 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   5               abcd,ea->bcde                      fb,gc,hd,bcde->efgh
   5               bcde,fb->cdef                         gc,hd,cdef->efgh
   5               cdef,gc->defg                            hd,defg->efgh
   5               defg,hd->efgh                               efgh->efgh
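The summary figures for this path can be re-derived from the step table alone. The sketch below reads the four steps listed under "current" and again assumes a cost of one multiply and one add per point of each step's index space. That cost model is an illustrative assumption, not a statement of how the library counts.

```python
# Contraction steps copied from the "current" column of the table above.
steps = ['abcd,ea->bcde', 'bcde,fb->cdef', 'cdef,gc->defg', 'defg,hd->efgh']
dim = 10  # every index has length 10 in this example

max_scaling = 0
total_flops = 0
largest_intermediate = 0
for step in steps:
    inputs, out = step.split('->')
    indices = set(inputs.replace(',', ''))
    # Scaling is the number of distinct indices the step loops over.
    max_scaling = max(max_scaling, len(indices))
    # Assumed cost: one multiply and one add per index-space point.
    total_flops += 2 * dim ** len(indices)
    # Each step's output is a candidate for the largest intermediate.
    largest_intermediate = max(largest_intermediate, dim ** len(out))

print(max_scaling, total_flops, largest_intermediate)  # 5 800000 10000
```

Every step touches five indices, which accounts for the optimized scaling of 5, and every intermediate has four indices of length 10, which accounts for the 10,000-element largest intermediate.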