jax.numpy.outer

jax.numpy.outer(a, b, out=None)

Compute the outer product of two arrays.

JAX implementation of numpy.outer().
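As a quick sanity check of the claim that this mirrors numpy.outer(), the sketch below computes the same product with both libraries and compares the values. It assumes only that NumPy and JAX are installed. The two results differ in array type (jax.Array vs numpy.ndarray), not in their contents.

```python
import numpy as np
import jax.numpy as jnp

# Same plain-Python inputs for both libraries.
a = [1, 2, 3]
b = [4, 5, 6]

result_jax = jnp.outer(jnp.array(a), jnp.array(b))
result_np = np.outer(np.array(a), np.array(b))

# Convert the JAX result to NumPy and compare element-wise.
print(np.array_equal(np.asarray(result_jax), result_np))  # True
```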

Parameters:
  • a (ArrayLike) – first input array. If it is not 1D, it is flattened.

  • b (ArrayLike) – second input array. If it is not 1D, it is flattened.

  • out (None) – unsupported by JAX.
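The flattening rule means a multidimensional input contributes one row (or column) per element. The sketch below, with arbitrary illustrative inputs, passes a 2x2 array as a. It is flattened in row-major order to [1, 2, 3, 4], so the result has 4 rows.

```python
import jax.numpy as jnp

a = jnp.array([[1, 2], [3, 4]])  # shape (2, 2), size 4
b = jnp.array([10, 20, 30])      # shape (3,), size 3

result = jnp.outer(a, b)
print(result.shape)  # (4, 3)

# The first column is 10 times the flattened a: 10, 20, 30, 40.
print(result[:, 0])
```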

Returns:

The outer product of the inputs a and b. The returned array has shape (a.size, b.size).
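The shape (a.size, b.size) follows from the definition result[i, j] = ravel(a)[i] * ravel(b)[j]. The sketch below writes this out with broadcasting. The helper outer_via_broadcasting is hypothetical and only illustrates the definition. It is not claimed to be how JAX implements jnp.outer internally, though it should produce the same values.

```python
import jax.numpy as jnp

def outer_via_broadcasting(a, b):
    """Hypothetical helper: result[i, j] = ravel(a)[i] * ravel(b)[j]."""
    a = jnp.ravel(jnp.asarray(a))
    b = jnp.ravel(jnp.asarray(b))
    # A column of shape (a.size, 1) times a row of shape (1, b.size)
    # broadcasts to (a.size, b.size).
    return a[:, None] * b[None, :]

x = jnp.array([1, 2, 3])
y = jnp.array([[4, 5], [6, 7]])  # size 4 after flattening

print(outer_via_broadcasting(x, y).shape)  # (3, 4)
print(jnp.array_equal(outer_via_broadcasting(x, y), jnp.outer(x, y)))  # True
```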

Return type:

Array

Examples

>>> import jax.numpy as jnp
>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.outer(a, b)
Array([[ 4,  5,  6],
       [ 8, 10, 12],
       [12, 15, 18]], dtype=int32)
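Inputs of different dtypes are combined under JAX's usual type promotion. The sketch below assumes the default configuration with jax_enable_x64 off, so integer lists become int32 and float lists become float32. Multiplying an int32 array by a float32 array yields float32 in JAX, so the product is float32. With x64 enabled, the dtypes would instead be 64-bit.

```python
import jax.numpy as jnp

ints = jnp.array([1, 2])        # int32 under the default 32-bit config
floats = jnp.array([0.5, 1.5])  # float32 under the default 32-bit config

result = jnp.outer(ints, floats)
print(result.dtype)  # float32
print(result)        # values [[0.5, 1.5], [1.0, 3.0]]
```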