jax.numpy.linalg.tensorinv

jax.numpy.linalg.tensorinv(a, ind=2)

Compute the ‘inverse’ of an N-dimensional array.

LAX-backend implementation of numpy.linalg.tensorinv().

Original docstring below.

The result is an inverse of a relative to the tensordot operation tensordot(a, b, ind). That is, up to floating-point accuracy, tensordot(tensorinv(a), a, ind) is the “identity” tensor for the tensordot operation.
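The identity property can be checked directly. The sketch below is illustrative: the PRNG key, the choice of a well-conditioned 24×24 matrix (identity plus small noise) reshaped to (4, 6, 8, 3), and the tolerance are all assumptions, and the tolerance presumes float32 on a CPU backend.

```python
import jax
import jax.numpy as jnp

# A "square" tensor for ind=2: prod((4, 6)) == 24 == prod((8, 3)).
key = jax.random.PRNGKey(0)
a = jnp.eye(24) + 0.01 * jax.random.normal(key, (24, 24))  # well-conditioned
a = a.reshape(4, 6, 8, 3)

ainv = jnp.linalg.tensorinv(a, ind=2)
print(ainv.shape)  # (8, 3, 4, 6)

# Contract the last 2 axes of ainv with the first 2 axes of a.
ident = jnp.tensordot(ainv, a, axes=2)
print(ident.shape)  # (8, 3, 8, 3)

# The result should match the 24x24 identity viewed with shape (8, 3, 8, 3).
print(jnp.allclose(ident, jnp.eye(24).reshape(8, 3, 8, 3), atol=1e-4))
```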

Parameters
  • a (array_like) – Tensor to ‘invert’. Its shape must be ‘square’, i.e., prod(a.shape[:ind]) == prod(a.shape[ind:]).

  • ind (int, optional) – Number of leading indices involved in the inverse sum. Must be a positive integer; defaults to 2.
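One way to see why a must be ‘square’ and what ind controls: flatten the first ind axes into matrix rows and the remaining axes into columns, invert that ordinary matrix, and reshape back. The helper below, tensorinv_sketch, is a hypothetical name used only for illustration. It follows the reference NumPy approach and is not necessarily how JAX implements tensorinv internally. The test tensor and tolerance are likewise assumptions.

```python
import math

import jax
import jax.numpy as jnp


def tensorinv_sketch(a, ind=2):
    """Hypothetical helper: tensor inverse via an ordinary matrix inverse."""
    a = jnp.asarray(a)
    if ind <= 0:
        raise ValueError("ind must be a positive integer")
    leading, trailing = a.shape[:ind], a.shape[ind:]
    n_rows, n_cols = math.prod(leading), math.prod(trailing)
    if n_rows != n_cols:
        raise ValueError(
            "a must be 'square': prod(a.shape[:ind]) == prod(a.shape[ind:])"
        )
    # Flatten the first `ind` axes into rows and the rest into columns.
    m_inv = jnp.linalg.inv(a.reshape(n_rows, n_cols))
    # Rows of m_inv index the trailing axes of a and columns index the
    # leading axes, hence the output shape a.shape[ind:] + a.shape[:ind].
    return m_inv.reshape(trailing + leading)


key = jax.random.PRNGKey(0)
a = (jnp.eye(12) + 0.01 * jax.random.normal(key, (12, 12))).reshape(3, 4, 2, 6)
b = tensorinv_sketch(a, ind=2)
print(b.shape)  # (2, 6, 3, 4)
print(jnp.allclose(b, jnp.linalg.tensorinv(a, ind=2), atol=1e-5))
```

A tensor of shape (2, 3, 4) with ind=1 would be rejected by this sketch, since 2 != 3 * 4.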

Returns

b – a’s tensordot inverse, with shape a.shape[ind:] + a.shape[:ind].

Return type

ndarray
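The returned shape moves the first ind axes of a to the end. The sketch below views one 6×6 matrix as two differently shaped tensors to show how ind changes the output shape. The matrix, key, and tolerance are illustrative assumptions. It also shows that inverting the result again, with ind equal to the number of axes that now come first, restores the original axis order up to floating-point error.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(1)
# One well-conditioned 6x6 matrix, viewed as two differently shaped tensors.
m = jnp.eye(6) + 0.01 * jax.random.normal(key, (6, 6))

# ind=1: a.shape[:1] = (6,), a.shape[1:] = (2, 3); 6 == 2 * 3.
a1 = m.reshape(6, 2, 3)
b1 = jnp.linalg.tensorinv(a1, ind=1)
print(b1.shape)  # (2, 3, 6)

# ind=2: a.shape[:2] = (2, 3), a.shape[2:] = (6,); 2 * 3 == 6.
a2 = m.reshape(2, 3, 6)
b2 = jnp.linalg.tensorinv(a2, ind=2)
print(b2.shape)  # (6, 2, 3)

# b1 starts with the 2 trailing axes of a1, so inverting with ind=2
# returns to a1's shape (6, 2, 3) and, approximately, to a1's values.
a1_again = jnp.linalg.tensorinv(b1, ind=2)
print(a1_again.shape)  # (6, 2, 3)
```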