jax.numpy.linalg.tensorinv

jax.numpy.linalg.tensorinv(a, ind=2)

Compute the tensor inverse of an array.

JAX implementation of numpy.linalg.tensorinv().

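This computes the inverse of the tensordot() operation with the same ind value: tensordot(tensorinv(a, ind), a, ind) equals an identity matrix reshaped to (*a.shape[ind:], *a.shape[ind:]).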
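One way to see why this works is the following minimal sketch, which assumes that tensor inversion reduces to an ordinary matrix inverse of a reshaped view of a (the approach NumPy's reference implementation takes). The helper name tensorinv_sketch is hypothetical and is not JAX's source. The input is an identity plus a small random perturbation, chosen only so the float32 comparison with jnp.linalg.tensorinv is numerically stable.

```python
import math

import jax
import jax.numpy as jnp


def tensorinv_sketch(a, ind=2):
    # Hypothetical helper, for illustration only: not JAX's actual source.
    a = jnp.asarray(a)
    if ind <= 0:
        raise ValueError("ind must be a positive integer")
    # The leading `ind` axes are flattened into rows, the rest into columns.
    rows = math.prod(a.shape[:ind])
    cols = math.prod(a.shape[ind:])
    if rows != cols:
        raise ValueError("prod(a.shape[:ind]) must equal prod(a.shape[ind:])")
    # Invert the square matrix view, then unflatten with the two axis
    # groups swapped: (*a.shape[ind:], *a.shape[:ind]).
    inv_matrix = jnp.linalg.inv(a.reshape(rows, cols))
    return inv_matrix.reshape(a.shape[ind:] + a.shape[:ind])


key = jax.random.key(0)
# Identity plus a small perturbation keeps the flattened 4x4 matrix well conditioned.
x = (jnp.eye(4) + 0.1 * jax.random.normal(key, (4, 4))).reshape(2, 2, 4)
print(tensorinv_sketch(x, 2).shape)  # (4, 2, 2)
print(jnp.allclose(tensorinv_sketch(x, 2), jnp.linalg.tensorinv(x, 2), atol=1e-5))
```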

Parameters:
  • a (jax.typing.ArrayLike) – array to be inverted. Must satisfy prod(a.shape[:ind]) == prod(a.shape[ind:]).

  • ind (int) – positive integer specifying the number of indices in the tensor product, i.e. how many leading axes of a form the first index group (default 2).
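Because the precondition on a depends on ind, a given shape often admits only one choice of ind. The helper below (valid_inds, a hypothetical name) is a small sketch that enumerates every positive ind satisfying prod(a.shape[:ind]) == prod(a.shape[ind:]); it only inspects shapes and does not call JAX.

```python
import math


def valid_inds(shape):
    # Hypothetical helper: list every positive `ind` for which the leading
    # `ind` axes hold as many elements as the trailing axes.
    return [
        ind
        for ind in range(1, len(shape) + 1)
        if math.prod(shape[:ind]) == math.prod(shape[ind:])
    ]


print(valid_inds((2, 2, 4)))  # [2]
print(valid_inds((6, 2, 3)))  # [1]
print(valid_inds((3, 5)))     # []
```

So an array of shape (2, 2, 4), like the one in the example, can only be inverted with ind=2.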

Returns:

array of shape (*a.shape[ind:], *a.shape[:ind]) containing the tensor inverse of a.
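To make the returned layout concrete, this sketch uses ind=1 on an array of shape (6, 2, 3), so the result has shape (2, 3, 6): the trailing axes of a come first. Contracting the result with a over ind axes then gives jnp.eye(6) reshaped to (2, 3, 2, 3). Building the input as an identity plus a small random perturbation is an assumption made only so the float32 check is well conditioned.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(42)
# Well-conditioned 6x6 matrix, viewed as a tensor of shape (6, 2, 3).
a = (jnp.eye(6) + 0.1 * jax.random.normal(key, (6, 6))).reshape(6, 2, 3)

ainv = jnp.linalg.tensorinv(a, ind=1)
print(ainv.shape)  # (2, 3, 6)

# Contract the last `ind` axes of ainv with the first `ind` axes of a.
identity = jnp.tensordot(ainv, a, axes=1)
print(identity.shape)  # (2, 3, 2, 3)
print(jnp.allclose(identity, jnp.eye(6).reshape(2, 3, 2, 3), atol=1e-4))
```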

Return type:

Array

Example

>>> import jax
>>> import jax.numpy as jnp
>>> key = jax.random.key(1337)
>>> x = jax.random.normal(key, shape=(2, 2, 4))
>>> xinv = jnp.linalg.tensorinv(x, 2)
>>> xinv_x = jnp.linalg.tensordot(xinv, x, axes=2)
>>> jnp.allclose(xinv_x, jnp.eye(4), atol=1E-4)
Array(True, dtype=bool)
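A common use of the tensor inverse is solving a tensordot equation. The sketch below assumes the goal is to find x with jnp.tensordot(a, x, axes=x.ndim) == b, for a of shape (2, 2, 4) and b of shape (2, 2). Since tensorinv(a, ind=b.ndim) has shape (4, 2, 2), contracting it with b over b.ndim axes yields x. The result is compared against jnp.linalg.tensorsolve, which solves the same equation without forming an explicit inverse; as with ordinary matrices, solving directly is generally the better-conditioned choice. The well-conditioned input is again an assumption made for a stable float32 check.

```python
import jax
import jax.numpy as jnp

key_a, key_b = jax.random.split(jax.random.key(7))
# Well-conditioned coefficient tensor of shape (2, 2, 4); b has shape (2, 2).
a = (jnp.eye(4) + 0.1 * jax.random.normal(key_a, (4, 4))).reshape(2, 2, 4)
b = jax.random.normal(key_b, (2, 2))

# Solve tensordot(a, x, axes=1) == b using the tensor inverse of a.
ainv = jnp.linalg.tensorinv(a, ind=b.ndim)   # shape (4, 2, 2)
x_inv = jnp.tensordot(ainv, b, axes=b.ndim)  # shape (4,)

# The same solution obtained without forming an explicit inverse.
x_solve = jnp.linalg.tensorsolve(a, b)

print(jnp.allclose(x_inv, x_solve, atol=1e-4))
print(jnp.allclose(jnp.tensordot(a, x_inv, axes=1), b, atol=1e-4))
```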