jax.numpy.outer

jax.numpy.outer(a, b, out=None)

Compute the outer product of two vectors.

LAX-backend implementation of numpy.outer().
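Because this is a reimplementation of NumPy's function on the LAX backend, it should produce the same values as `numpy.outer` for ordinary inputs. Here is a minimal check, assuming both `jax` and `numpy` are installed. The input values are arbitrary and chosen only for illustration.

```python
import numpy as np
import jax.numpy as jnp

# Arbitrary example inputs, shared by both libraries.
a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
b = np.array([4.0, 5.0], dtype=np.float32)

# Reference result from NumPy and the JAX result on the same inputs.
expected = np.outer(a, b)
result = jnp.outer(a, b)

# result is a jax.Array; convert it so NumPy's checker can compare values.
np.testing.assert_allclose(np.asarray(result), expected, rtol=1e-6)
print(result.shape)  # (3, 2)
```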

Original docstring below.

Given two vectors a and b of length M and N, respectively, the outer product [1] is:

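$$
\begin{bmatrix}
a_0 b_0 & a_0 b_1 & \cdots & a_0 b_{N-1} \\
a_1 b_0 & a_1 b_1 & \cdots & a_1 b_{N-1} \\
\vdots & \vdots & \ddots & \vdots \\
a_{M-1} b_0 & a_{M-1} b_1 & \cdots & a_{M-1} b_{N-1}
\end{bmatrix}
$$
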
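The following example fills in that layout with concrete numbers for M = 3 and N = 2. The input values are arbitrary. The sketch also compares the result to an equivalent broadcasting expression. The two agree because of the definition; this says nothing about how JAX computes the product internally.

```python
import jax.numpy as jnp

a = jnp.array([1, 2, 3])   # M = 3
b = jnp.array([10, 20])    # N = 2

result = jnp.outer(a, b)
# Row i holds a[i] * b, matching the matrix layout above:
# [[10, 20],
#  [20, 40],
#  [30, 60]]

# The same grid, built by broadcasting a column vector against a row vector.
broadcast = a[:, None] * b[None, :]
assert jnp.array_equal(result, broadcast)
```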
Parameters:
  • a ((M,) array_like) – First input vector. Input is flattened if not already 1-dimensional.

  • b ((N,) array_like) – Second input vector. Input is flattened if not already 1-dimensional.

  • out (None) – Unsupported by JAX; must be left as None.
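The next sketch illustrates the flattening rule by passing a 2-D array as `a`. The shapes are arbitrary illustrative choices, and `out` is left at its default.

```python
import jax.numpy as jnp

a = jnp.arange(4).reshape(2, 2)   # 2-D input holding 4 elements
b = jnp.array([1, 10, 100])       # 1-D input holding 3 elements

result = jnp.outer(a, b)

# a is flattened to [0, 1, 2, 3] first, so M = 4 and N = 3.
assert result.shape == (4, 3)
assert jnp.array_equal(result, jnp.outer(a.ravel(), b))
```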

Returns:

out – out[i, j] = a[i] * b[j]

Return type:

(M, N) ndarray
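The check below tests the return contract directly: an (M, N) result whose entries satisfy out[i, j] = a[i] * b[j]. It also writes the same product in `jnp.einsum` index notation, where no index is summed. The input values are arbitrary, and they were chosen so that every product is exact in float32.

```python
import jax.numpy as jnp

a = jnp.array([2.0, -1.0, 0.5])        # shape (M,) with M = 3
b = jnp.array([3.0, 4.0, 5.0, 6.0])    # shape (N,) with N = 4

out = jnp.outer(a, b)
M, N = a.shape[0], b.shape[0]

# The return value has shape (M, N) ...
assert out.shape == (M, N)

# ... and each entry follows out[i, j] = a[i] * b[j].
for i in range(M):
    for j in range(N):
        assert out[i, j] == a[i] * b[j]

# Same product in einsum notation: "i,j->ij" keeps both indices, sums none.
assert jnp.allclose(out, jnp.einsum("i,j->ij", a, b))
```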

References