jax.numpy.linalg.tensorsolve

jax.numpy.linalg.tensorsolve(a, b, axes=None)

Solve the tensor equation a x = b for x.

JAX implementation of numpy.linalg.tensorsolve().
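
Conceptually, the tensor equation reduces to an ordinary square linear system: the leading b.ndim axes of a are flattened into rows and the remaining axes into columns. The sketch below illustrates that reduction for axes=None only. tensorsolve_sketch is a hypothetical helper written for illustration, not JAX's actual implementation, and the near-identity a is an assumption chosen so the system is well conditioned.

```python
import math

import jax
import jax.numpy as jnp


def tensorsolve_sketch(a, b):
    """Hypothetical helper: solve tensordot(a, x, x.ndim) == b without axes reordering."""
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    # The leading axes of a must match b; the trailing axes give the shape of x.
    if a.shape[:b.ndim] != b.shape:
        raise ValueError(f"leading dims of a {a.shape} must equal b.shape {b.shape}")
    x_shape = a.shape[b.ndim:]
    # Flatten to a (b.size, prod(x_shape)) matrix; jnp.linalg.solve needs it
    # to be square, i.e. b.size == prod(x_shape).
    a_mat = a.reshape(b.size, math.prod(x_shape))
    x_flat = jnp.linalg.solve(a_mat, b.ravel())
    # Restore the tensor shape of the unknown.
    return x_flat.reshape(x_shape)


key1, key2 = jax.random.split(jax.random.key(0))
# Near-identity coefficients (assumption) keep the 4x4 system well conditioned.
a = jnp.eye(4).reshape(2, 2, 4) + 0.1 * jax.random.normal(key1, shape=(2, 2, 4))
b = jax.random.normal(key2, shape=(2, 2))

x_ref = jnp.linalg.tensorsolve(a, b)
x_sketch = tensorsolve_sketch(a, b)
print(x_sketch.shape, jnp.allclose(x_ref, x_sketch, atol=1e-4))
```

Flattening and reshaping is what makes the "tensor" solve an ordinary matrix solve; it is also why the product of the leading dimensions of a must equal the product of its trailing dimensions.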

Parameters:
  • a (jax.typing.ArrayLike) – input array. After reordering via axes (see below), its shape must be (*b.shape, *x.shape), with prod(b.shape) == prod(x.shape) so that the flattened system is square.

  • b (jax.typing.ArrayLike) – right-hand-side array.

  • axes (tuple[int, ...] | None) – optional tuple specifying axes of a that should be moved to the end before solving. If None (default), no reordering is done.
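
As an illustration of axes, the sketch below moves a single axis. It assumes coefficients stored with the x axis first (shape (4, 2, 2)); passing axes=(0,) should then give the same x as solving against the conventionally ordered array directly. The names a_last and a_first, the seed, the shapes, and the near-identity construction of the coefficients are arbitrary choices for this example.

```python
import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.key(0))
# Well-conditioned coefficients in the usual layout (*b.shape, *x.shape) = (2, 2, 4).
a_last = jnp.eye(4).reshape(2, 2, 4) + 0.1 * jax.random.normal(key1, shape=(2, 2, 4))
b = jax.random.normal(key2, shape=(2, 2))

# Same coefficients with the x axis stored first: shape (4, 2, 2).
a_first = jnp.moveaxis(a_last, -1, 0)

# axes=(0,) moves axis 0 of a_first to the end before solving.
x_axes = jnp.linalg.tensorsolve(a_first, b, axes=(0,))
x_direct = jnp.linalg.tensorsolve(a_last, b)
print(x_axes.shape, jnp.allclose(x_axes, x_direct, atol=1e-5))
```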

Returns:

Array x such that, after any reordering of the axes of a, tensordot(a, x, x.ndim) equals b up to floating-point error.

Return type:

Array

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> key1, key2 = jax.random.split(jax.random.key(8675309))
>>> a = jax.random.normal(key1, shape=(2, 2, 4))
>>> b = jax.random.normal(key2, shape=(2, 2))
>>> x = jnp.linalg.tensorsolve(a, b)
>>> x.shape
(4,)

Now show that x can be used to reconstruct b using tensordot():

>>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim)
>>> jnp.allclose(b, b_reconstructed)
Array(True, dtype=bool)
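
The same reconstruction check applies when x has more than one axis. In the sketch below the shapes are arbitrary, and a is built close to an identity (an assumption made so the system is well conditioned): a has shape (6, 2, 3) and b has shape (6,), so x has shape (2, 3). It uses jnp.tensordot, which contracts the trailing x.ndim axes of a with the leading axes of x.

```python
import jax
import jax.numpy as jnp

key1, key2 = jax.random.split(jax.random.key(0))
# a.shape == (*b.shape, *x.shape) == (6, 2, 3); note 6 == 2 * 3.
a = jnp.eye(6).reshape(6, 2, 3) + 0.1 * jax.random.normal(key1, shape=(6, 2, 3))
b = jax.random.normal(key2, shape=(6,))

x = jnp.linalg.tensorsolve(a, b)
# Contract the trailing x.ndim == 2 axes of a against x to recover b.
b_reconstructed = jnp.tensordot(a, x, axes=x.ndim)
print(x.shape, jnp.allclose(b, b_reconstructed, atol=1e-4))
```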