jax.numpy.outer
- jax.numpy.outer(a, b, out=None)
Compute the outer product of two arrays.

JAX implementation of numpy.outer().

- Parameters:
a (ArrayLike) – first input array; if it is not 1D, it is flattened.
b (ArrayLike) – second input array; if it is not 1D, it is flattened.
out (None) – unsupported by JAX.
- Returns:
The outer product of the inputs a and b. The returned array has shape (a.size, b.size).
- Return type:
Array
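The flattening rule for non-1D inputs can be checked directly. The sketch below assumes nothing beyond public `jax.numpy` functions. It passes a 2D `a` and compares the result with an explicit ravel-and-broadcast formulation, where each element is `a.ravel()[i] * b.ravel()[j]`:

```python
import jax.numpy as jnp

# A 2x2 array and a length-3 vector; jnp.outer flattens any non-1D input first.
a = jnp.array([[1, 2], [3, 4]])
b = jnp.array([10, 20, 30])

result = jnp.outer(a, b)
print(result.shape)  # (4, 3), i.e. (a.size, b.size)

# Equivalent formulation via explicit flattening and broadcasting:
# result[i, j] == a.ravel()[i] * b.ravel()[j]
manual = a.ravel()[:, None] * b.ravel()[None, :]
print(jnp.array_equal(result, manual))  # True
```

Because of this flattening, the output is always 2D, whatever the input ranks.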
See also
- jax.numpy.inner(): compute the inner product of two arrays.
- jax.numpy.einsum(): Einstein summation.
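For 1D inputs, the outer product can also be written as an Einstein summation with no summed index, `"i,j->ij"`. Unlike `jnp.outer`, this `einsum` form does not flatten its inputs, so the sketch below uses 1D arrays only. It also contrasts `jnp.inner`, which sums over the shared axis instead:

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0])

# 'i,j->ij': no index is summed, so each output element is a[i] * b[j].
via_einsum = jnp.einsum("i,j->ij", a, b)
via_outer = jnp.outer(a, b)

print(jnp.allclose(via_einsum, via_outer))  # True

# By contrast, jnp.inner sums over the shared last axis and returns a scalar
# for 1D inputs of equal length: 1*4 + 2*5 + 3*6.
print(jnp.inner(a, jnp.array([4.0, 5.0, 6.0])))  # 32.0
```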
Examples
>>> import jax.numpy as jnp
>>> a = jnp.array([1, 2, 3])
>>> b = jnp.array([4, 5, 6])
>>> jnp.outer(a, b)
Array([[ 4,  5,  6],
       [ 8, 10, 12],
       [12, 15, 18]], dtype=int32)
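Because `jnp.outer` flattens its inputs, it cannot by itself compute one outer product per row of a batch. The sketch below assumes that per-row behaviour is what you want. It is a usage pattern built on `jax.vmap` and `jax.jit`, not a feature of `jnp.outer` itself:

```python
import jax
import jax.numpy as jnp

# A batch of 5 pairs of vectors: xs has shape (5, 3), ys has shape (5, 4).
xs = jnp.arange(15.0).reshape(5, 3)
ys = jnp.ones((5, 4))

# Calling jnp.outer(xs, ys) directly would flatten both inputs and return a
# (15, 20) array. jax.vmap maps over the leading axis instead, giving one
# (3, 4) outer product per batch element.
batched_outer = jax.jit(jax.vmap(jnp.outer))
out = batched_outer(xs, ys)
print(out.shape)  # (5, 3, 4)

# Same result written with broadcasting: out[n, i, j] == xs[n, i] * ys[n, j].
print(jnp.allclose(out, xs[:, :, None] * ys[:, None, :]))  # True
```

The broadcasting form on the last line is an equally valid alternative. `vmap` keeps the per-example code identical to the unbatched call.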