jax.numpy.linalg.tensorinv

jax.numpy.linalg.tensorinv(a, ind=2)

Compute the ‘inverse’ of an N-dimensional array.

LAX-backend implementation of numpy.linalg.tensorinv().
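Because this is a LAX-backed counterpart of numpy.linalg.tensorinv(), the two should agree up to floating-point precision. The sketch below checks that on an arbitrary, deliberately well-conditioned test tensor (the shape, seed, and tolerance are illustrative assumptions). JAX defaults to float32, so a loose tolerance is used.

```python
import numpy as np
import jax.numpy as jnp

# Illustrative input: a well-conditioned 24x24 matrix viewed as a (4, 6, 8, 3)
# tensor. prod((4, 6)) == prod((8, 3)) == 24, so the default ind=2 is valid.
rng = np.random.default_rng(0)
a_np = (np.eye(24) + 0.05 * rng.standard_normal((24, 24))).reshape(4, 6, 8, 3)
a_np = a_np.astype(np.float32)

b_jax = jnp.linalg.tensorinv(jnp.asarray(a_np), ind=2)
b_np = np.linalg.tensorinv(a_np, ind=2)

print(b_jax.shape)  # (8, 3, 4, 6)
# JAX computes in float32 by default, so compare with a loose tolerance.
print(np.allclose(np.asarray(b_jax), b_np, atol=1e-4))
```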

Original docstring below.

The result is an inverse for a relative to the tensordot operation tensordot(a, b, ind); that is, up to floating-point accuracy, tensordot(tensorinv(a), a, ind) is the “identity” tensor for the tensordot operation.
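The identity property can be checked directly. In this sketch (the input tensor is an arbitrary illustrative choice), tensordot contracts the last ind axes of tensorinv(a), which have shape a.shape[:ind], against the first ind axes of a. The result has shape a.shape[ind:] + a.shape[ind:], and flattening each group of axes should give an ordinary identity matrix.

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(1)
ind = 2
# Illustrative well-conditioned (4, 6, 8, 3) tensor; 4 * 6 == 8 * 3 == 24.
a = jnp.asarray(np.eye(24) + 0.05 * rng.standard_normal((24, 24)),
                dtype=jnp.float32).reshape(4, 6, 8, 3)
ainv = jnp.linalg.tensorinv(a, ind=ind)  # shape (8, 3, 4, 6)

# Contract the last `ind` axes of ainv, shape (4, 6), with the first `ind` axes of a.
ident = jnp.tensordot(ainv, a, ind)
print(ident.shape)  # (8, 3, 8, 3)

# Flattening the two (8, 3) axis groups gives a 24x24 matrix close to the identity.
n = 8 * 3
print(jnp.allclose(ident.reshape(n, n), jnp.eye(n), atol=1e-4))
```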

Parameters:
  • a (array_like) – Tensor to ‘invert’. Its shape must be ‘square’, i.e., prod(a.shape[:ind]) == prod(a.shape[ind:]).

  • ind (int, optional) – Number of leading indices involved in the inverse sum. Must be a positive integer; the default is 2.
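The choice of ind decides where the shape is split, and only some splits satisfy the ‘square’ condition. The sketch below uses a hypothetical helper, is_square_split (not part of JAX), to test the condition for a (6, 2, 3) tensor, then inverts it with ind=1. The diagonal test input is an illustrative assumption chosen so the expected inverse is easy to write down.

```python
import math
import numpy as np
import jax.numpy as jnp

def is_square_split(shape, ind):
    """Hypothetical helper (not a JAX API): True if `shape` satisfies
    prod(shape[:ind]) == prod(shape[ind:]) for a positive `ind`."""
    return ind > 0 and math.prod(shape[:ind]) == math.prod(shape[ind:])

shape = (6, 2, 3)
print(is_square_split(shape, 1))  # True:  6 == 2 * 3
print(is_square_split(shape, 2))  # False: 6 * 2 != 3

# Diagonal 6x6 matrix with entries 1.0 ... 1.5, viewed as a (6, 2, 3) tensor.
a = jnp.asarray(np.eye(6) + np.diag(np.arange(6) / 10.0),
                dtype=jnp.float32).reshape(shape)
# ind=1 pairs the leading axis (size 6) with the trailing (2, 3) block.
b = jnp.linalg.tensorinv(a, ind=1)
print(b.shape)  # (2, 3, 6)
```

For a diagonal input, flattening b back to 6x6 should recover the reciprocal diagonal entries.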

Returns:

b – a’s tensordot inverse, shape a.shape[ind:] + a.shape[:ind].
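The returned shape follows from viewing a as a square matrix whose rows index the leading ind axes and whose columns index the rest: inverting that matrix swaps the roles of the two axis groups. The function tensorinv_sketch below is a hypothetical illustration of this flatten, invert, reshape idea (mirroring the approach NumPy takes), not a claim about JAX's exact source. It is compared against jnp.linalg.tensorinv on an arbitrary test tensor.

```python
import math
import numpy as np
import jax.numpy as jnp

def tensorinv_sketch(a, ind=2):
    """Hypothetical illustration (not the JAX source): flatten, invert, reshape."""
    a = jnp.asarray(a)
    if ind <= 0:
        raise ValueError("ind must be a positive integer")
    head, tail = a.shape[:ind], a.shape[ind:]
    n = math.prod(head)
    if n != math.prod(tail):
        raise ValueError("need prod(a.shape[:ind]) == prod(a.shape[ind:])")
    # Rows index the leading `ind` axes; columns index the remaining axes.
    inv_mat = jnp.linalg.inv(a.reshape(n, n))
    # The inverse swaps the two axis groups: shape a.shape[ind:] + a.shape[:ind].
    return inv_mat.reshape(tail + head)

rng = np.random.default_rng(2)
# Illustrative well-conditioned (3, 4, 2, 6) tensor; 3 * 4 == 2 * 6 == 12.
a = jnp.asarray(np.eye(12) + 0.05 * rng.standard_normal((12, 12)),
                dtype=jnp.float32).reshape(3, 4, 2, 6)
b = tensorinv_sketch(a)
print(b.shape)  # (2, 6, 3, 4)
print(jnp.allclose(b, jnp.linalg.tensorinv(a), atol=1e-5))
```

Reducing the problem to a single dense matrix inverse keeps the sketch simple; it also means the cost grows with the cube of prod(a.shape[:ind]).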

Return type:

ndarray