jax.numpy.inner

jax.numpy.inner(a, b, *, precision=None, preferred_element_type=None)

Compute the inner product of two arrays.

JAX implementation of numpy.inner().
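Since this function is the JAX counterpart of numpy.inner(), one quick sanity check is to compare the two on the same inputs. This is a minimal sketch that assumes NumPy is installed alongside JAX. The tolerance is deliberately loose because some backends use reduced default precision for float32 contractions.

```python
import numpy as np
import jax.numpy as jnp

# Small random float32 inputs whose last dimensions match (N = 4).
rng = np.random.default_rng(0)
a_np = rng.standard_normal((3, 4)).astype(np.float32)
b_np = rng.standard_normal((2, 4)).astype(np.float32)

expected = np.inner(a_np, b_np)  # NumPy reference result, shape (3, 2)
result = jnp.inner(jnp.asarray(a_np), jnp.asarray(b_np))

print(result.shape)  # (3, 2)
# Loose tolerance: allows for reduced default precision on some accelerators.
print(np.allclose(np.asarray(result), expected, rtol=1e-3, atol=1e-3))
```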

Unlike jax.numpy.matmul() or jax.numpy.dot(), this always performs a contraction along the last dimension of each input.
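To make the contrast concrete: for 2-D inputs, jnp.dot(a, b) contracts the last axis of a with the first axis of b, whereas jnp.inner contracts the last axis of both. The sketch below uses illustrative shapes to check that jnp.inner(a, b) therefore matches jnp.dot(a, b.T).

```python
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)   # shape (2, 3)
b = jnp.arange(12.0).reshape(4, 3)  # shape (4, 3): last axis matches a's last axis

inner_ab = jnp.inner(a, b)  # contracts axis -1 of a with axis -1 of b -> shape (2, 4)
dot_ab_t = jnp.dot(a, b.T)  # contracts a's last axis with b.T's first axis -> shape (2, 4)

print(inner_ab.shape)                    # (2, 4)
print(jnp.allclose(inner_ab, dot_ab_t))  # True

# jnp.dot(a, b) on its own would fail: a's last axis (3) does not match b's first axis (4).
```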

Parameters:
  • a (jax.typing.ArrayLike) – array of shape (..., N)

  • b (jax.typing.ArrayLike) – array of shape (..., N)

  • precision (str | Precision | tuple[str, str] | tuple[Precision, Precision] | None) – either None (default), which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two such values indicating precision of a and b.

  • preferred_element_type (dtype | None) – either None (default), which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
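The two keyword-only parameters can be illustrated with a short sketch. The int8 example shows why preferred_element_type matters: the sum 3 × 100 × 100 = 30000 does not fit in int8, so the sketch requests int32 accumulation and output. The precision example passes jax.lax.Precision.HIGHEST. On CPU this typically makes no visible difference, because the hint mainly affects float32 contractions on accelerators such as TPUs and GPUs.

```python
import jax
import jax.numpy as jnp

# int8 vectors whose inner product (3 * 100 * 100 = 30000) overflows int8.
a = jnp.full((3,), 100, dtype=jnp.int8)
b = jnp.full((3,), 100, dtype=jnp.int8)

# Accumulate in int32 and return an int32 result.
out = jnp.inner(a, b, preferred_element_type=jnp.int32)
print(out.dtype)  # int32
print(out)        # 30000

# Request the highest available precision for float32 inputs.
x = jnp.linspace(0.0, 1.0, 8, dtype=jnp.float32)
y = jnp.linspace(1.0, 2.0, 8, dtype=jnp.float32)
hi = jnp.inner(x, y, precision=jax.lax.Precision.HIGHEST)
print(hi.dtype)   # float32
```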

Returns:

array of shape (*a.shape[:-1], *b.shape[:-1]) containing the inner product of every length-N vector in a with every length-N vector in b.

Return type:

Array
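The output shape rule can be checked with higher-rank inputs. In the sketch below, the einsum expression is one explicit way of writing the same contraction for these particular ranks. It is only a reference for comparison and does not describe how jnp.inner is implemented internally.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3, 4))  # a.shape[:-1] == (2, 3), N = 4
b = jnp.ones((5, 4))     # b.shape[:-1] == (5,),   N = 4

out = jnp.inner(a, b)
expected_shape = (*a.shape[:-1], *b.shape[:-1])
print(out.shape, expected_shape)  # (2, 3, 5) (2, 3, 5)

# Explicit reference for these ranks: out[i, j, l] = sum_k a[i, j, k] * b[l, k]
ref = jnp.einsum("ijk,lk->ijl", a, b)
print(jnp.allclose(out, ref))  # True
```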

Examples

For 1D inputs, this computes the standard (non-conjugated) dot product:

>>> import jax.numpy as jnp
>>> a = jnp.array([1j, 3j, 4j])
>>> b = jnp.array([4., 2., 5.])
>>> jnp.inner(a, b)
Array(0.+30.j, dtype=complex64)
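The "non-conjugated" qualifier matters for complex inputs. jnp.vdot conjugates its first argument, so on the same vectors it returns the negated imaginary result. The sketch below shows the two side by side, together with the explicit elementwise sum that jnp.inner agrees with for 1D inputs.

```python
import jax.numpy as jnp

a = jnp.array([1j, 3j, 4j])
b = jnp.array([4., 2., 5.])

# No conjugation: 1j*4 + 3j*2 + 4j*5 = 30j
print(jnp.inner(a, b))
# The same value written as an explicit elementwise product and sum.
print(jnp.sum(a * b))
# vdot conjugates the first argument: (-1j)*4 + (-3j)*2 + (-4j)*5 = -30j
print(jnp.vdot(a, b))
```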

For multi-dimensional inputs, the leading (batch) dimensions of a and b are stacked in the output rather than broadcast against each other:

>>> a = jnp.ones((2, 3))
>>> b = jnp.ones((5, 3))
>>> jnp.inner(a, b).shape
(2, 5)
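To see what "stacked rather than broadcast" means in practice, compare jnp.inner with a broadcasting elementwise product followed by a sum over the last axis. The broadcast version pairs row i of a only with row i of b. jnp.inner pairs every row of a with every row of b, so the broadcast results appear on the diagonal of the stacked result. The shapes are chosen purely for illustration.

```python
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)        # two length-3 vectors
b = jnp.arange(6.0, 12.0).reshape(2, 3)  # two more length-3 vectors

# inner: every row of a against every row of b -> shape (2, 2)
stacked = jnp.inner(a, b)

# broadcast alternative: row i of a against row i of b only -> shape (2,)
paired = jnp.sum(a * b, axis=-1)

print(stacked.shape, paired.shape)  # (2, 2) (2,)
# The paired (broadcast) results sit on the diagonal of the stacked result.
print(jnp.allclose(jnp.diagonal(stacked), paired))  # True
```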