jax.numpy.linalg.tensorsolve
- jax.numpy.linalg.tensorsolve(a, b, axes=None)
Solve the tensor equation a x = b for x.
JAX implementation of numpy.linalg.tensorsolve().
- Parameters:
- Returns:
array x such that, after reordering of the axes of a, tensordot(a, x, x.ndim) is equivalent to b.
- Return type:
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> key1, key2 = jax.random.split(jax.random.key(8675309))
>>> a = jax.random.normal(key1, shape=(2, 2, 4))
>>> b = jax.random.normal(key2, shape=(2, 2))
>>> x = jnp.linalg.tensorsolve(a, b)
>>> x.shape
(4,)
Now show that x can be used to reconstruct b using tensordot():
>>> b_reconstructed = jnp.linalg.tensordot(a, x, axes=x.ndim)
>>> jnp.allclose(b, b_reconstructed)
Array(True, dtype=bool)
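The following sketch illustrates two points that the example above does not cover. It is not taken from the JAX documentation, and the input data, the PRNG seed, and the variable names (x_manual, a_front, x_axes) are assumptions chosen for illustration. First, it assumes, as in NumPy's tensorsolve, that the result matches flattening a into a square matrix, solving the ordinary linear system, and reshaping the solution. Second, it assumes that axes lists the axes of a that are moved to the end before solving, which is the reordering the Returns description refers to.

```python
import jax
import jax.numpy as jnp

# Build a system that is typically well-conditioned: the identity plus a small
# random perturbation, viewed as a tensor of shape (*b.shape, *x.shape) = (2, 2, 4).
key1, key2 = jax.random.split(jax.random.key(0))
a = (jnp.eye(4) + 0.1 * jax.random.normal(key1, shape=(4, 4))).reshape(2, 2, 4)
b = jax.random.normal(key2, shape=(2, 2))

x = jnp.linalg.tensorsolve(a, b)

# Manual equivalent (assumed): flatten a to a (4, 4) matrix and b to a length-4
# vector, solve the linear system, then restore the trailing shape of a.
x_shape = a.shape[b.ndim:]
x_manual = jnp.linalg.solve(a.reshape(b.size, -1), b.ravel()).reshape(x_shape)
matches_manual = bool(jnp.allclose(x, x_manual, atol=1e-5))

# Store the same data with the "x" axis first, giving shape (4, 2, 2).
# Passing axes=(0,) asks tensorsolve to move axis 0 to the end before solving,
# which recovers the (2, 2, 4) layout used above.
a_front = jnp.moveaxis(a, -1, 0)
x_axes = jnp.linalg.tensorsolve(a_front, b, axes=(0,))
matches_axes = bool(jnp.allclose(x, x_axes, atol=1e-5))

print(x.shape, matches_manual, matches_axes)
```

If both assumptions hold, the two flags are True. The axes argument is mainly a convenience for inputs whose solution axes are not already trailing; an explicit jnp.moveaxis call before tensorsolve achieves the same layout.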