jax.numpy.linalg.outer

jax.numpy.linalg.outer(x1, x2, /)

Compute the outer product of two 1-dimensional arrays.

JAX implementation of numpy.linalg.outer().
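The outer product places x1[i] * x2[j] at position [i, j]. The sketch below, which assumes only jax.numpy, checks that jnp.linalg.outer agrees with two equivalent formulations: broadcasting a column vector against a row vector, and an einsum with no summed index. Neither is presented as the way the library computes the result internally.

```python
import jax.numpy as jnp

x1 = jnp.array([1, 2, 3])
x2 = jnp.array([4, 5, 6])

# out[i, j] = x1[i] * x2[j]: a column vector times a row vector, via broadcasting.
by_broadcast = x1[:, None] * x2[None, :]

# The same product written as an einsum with no summed index.
by_einsum = jnp.einsum("i,j->ij", x1, x2)

result = jnp.linalg.outer(x1, x2)
print(bool(jnp.array_equal(result, by_broadcast)))  # True
print(bool(jnp.array_equal(result, by_einsum)))     # True
```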

Parameters:
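  • x1 (jax.typing.ArrayLike) – first input array; must be 1-dimensional, with shape (M,)

  • x2 (jax.typing.ArrayLike) – second input array; must be 1-dimensional, with shape (N,)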
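Both parameters must be 1-dimensional, but their lengths may differ. The following sketch illustrates this; the ValueError caught for a 2-D input is an assumption about how current JAX versions report the error, not a documented guarantee, and rejects_2d is a hypothetical helper written only for this illustration.

```python
import jax.numpy as jnp

# Inputs may have different lengths: shapes (2,) and (3,) give a (2, 3) result.
a = jnp.array([1.0, 2.0])
b = jnp.array([3.0, 4.0, 5.0])
print(jnp.linalg.outer(a, b).shape)  # (2, 3)

# Hypothetical helper: returns True if a 2-D first argument is rejected.
# Assumption: JAX signals this with a ValueError rather than flattening the input.
def rejects_2d() -> bool:
    try:
        jnp.linalg.outer(jnp.ones((2, 2)), b)
    except ValueError:
        return True
    return False

print(rejects_2d())
```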

Returns:

array of shape (M, N) containing the outer product of x1 and x2, whose entry [i, j] is x1[i] * x2[j]

Return type:

Array
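For inputs of shapes (M,) and (N,) the result has shape (M, N), and its dtype follows JAX's usual type promotion for multiplication. A short check follows; it compares against jnp.result_type instead of a hard-coded dtype, so it holds whether or not 64-bit types are enabled.

```python
import jax.numpy as jnp

x1 = jnp.arange(2)   # integer array, shape (2,)
x2 = jnp.ones(3)     # floating-point array, shape (3,)
out = jnp.linalg.outer(x1, x2)

print(out.shape)                              # (2, 3)
print(out.dtype == jnp.result_type(x1, x2))   # True
```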

See also

jax.numpy.outer(): similar function in the main jax.numpy module; it flattens inputs that are not already 1-dimensional instead of requiring 1-dimensional arrays.
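The practical difference from jax.numpy.outer shows up with inputs that are not 1-dimensional: jnp.outer flattens them, following numpy.outer, while jnp.linalg.outer expects 1-D arrays. A sketch of the explicit equivalent, with arbitrary illustrative inputs:

```python
import jax.numpy as jnp

a = jnp.arange(4.0).reshape(2, 2)   # 2-D input
b = jnp.array([1.0, 10.0, 100.0])

# jnp.outer flattens the 2-D input first, so the result has shape (4, 3).
flat = jnp.outer(a, b)
print(flat.shape)  # (4, 3)

# With jnp.linalg.outer the flattening has to be done explicitly.
explicit = jnp.linalg.outer(a.ravel(), b)
print(bool(jnp.array_equal(flat, explicit)))  # True
```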

Example

>>> import jax.numpy as jnp
>>> x1 = jnp.array([1, 2, 3])
>>> x2 = jnp.array([4, 5, 6])
>>> jnp.linalg.outer(x1, x2)
Array([[ 4,  5,  6],
       [ 8, 10, 12],
       [12, 15, 18]], dtype=int32)
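Because the inputs must be 1-dimensional, batched outer products can be expressed with jax.vmap, which maps over a leading axis so that each call sees 1-D slices. The sketch below shows one way to do this (the batch and vector sizes are arbitrary), cross-checked against an einsum with an explicit batch index.

```python
import jax
import jax.numpy as jnp

# A batch of 4 pairs: xs has shape (4, 3), ys has shape (4, 2).
xs = jnp.arange(12.0).reshape(4, 3)
ys = jnp.arange(8.0).reshape(4, 2)

# vmap maps over the leading axis, so each call to outer sees 1-D inputs.
batched_outer = jax.jit(jax.vmap(jnp.linalg.outer))
out = batched_outer(xs, ys)
print(out.shape)  # (4, 3, 2)

# Cross-check against an einsum with an explicit batch index b.
print(bool(jnp.allclose(out, jnp.einsum("bi,bj->bij", xs, ys))))  # True
```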