# jax.numpy.inner

jax.numpy.inner(a, b, *, precision=None, preferred_element_type=None)

Compute the inner product of two arrays.

JAX implementation of `numpy.inner()`.

Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a contraction along the last dimension of each input.
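A minimal sketch of the difference, assuming two 2D inputs that share only their last dimension (the shapes are arbitrary choices for illustration). `jnp.inner` contracts the last axis of both arrays, while `jnp.matmul` contracts the last axis of `a` with the second-to-last axis of `b`, so `b` has to be transposed before `matmul` performs the same contraction.

```python
import jax.numpy as jnp

# Two matrices whose *last* dimensions agree (both have length 3).
a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(4, 3)

# inner contracts a's last axis with b's last axis -> shape (2, 4).
via_inner = jnp.inner(a, b)

# matmul contracts a's last axis with b's second-to-last axis,
# so b must be transposed to shape (3, 4) to get the same contraction.
via_matmul = jnp.matmul(a, b.T)

print(via_inner.shape)                            # (2, 4)
print(bool(jnp.allclose(via_inner, via_matmul)))  # True
```

Calling `jnp.matmul(a, b)` without the transpose would fail here, because the contracted lengths (3 and 4) do not match.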

Parameters:
• a (jax.typing.ArrayLike) – array of shape `(..., N)`

• b (jax.typing.ArrayLike) – array of shape `(..., N)`

• precision (str | Precision | tuple[str, str] | tuple[Precision, Precision] | None) – either `None` (default), meaning the backend's default precision; a `Precision` enum value (`Precision.DEFAULT`, `Precision.HIGH` or `Precision.HIGHEST`); or a tuple of two such values giving the precision for `a` and `b` respectively.

• preferred_element_type (dtype | None) – either `None` (default), meaning the default accumulation type for the input dtypes, or a dtype, in which case results are accumulated in and returned with that dtype.
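The two keyword-only parameters can be exercised as in the sketch below. The `int8` inputs, the choice of `jnp.int32`, and the `float32` vectors are illustrative assumptions; the point is that `preferred_element_type` sets both the accumulation dtype and the returned dtype, while `precision` accepts a `jax.lax.Precision` value.

```python
import jax
import jax.numpy as jnp

# int8 vectors whose inner product (3 * 100 * 100 = 30000)
# does not fit in int8.
x = jnp.full((3,), 100, dtype=jnp.int8)
y = jnp.full((3,), 100, dtype=jnp.int8)

# Accumulate in, and return, int32 instead of the default int8.
wide = jnp.inner(x, y, preferred_element_type=jnp.int32)
print(wide.dtype, int(wide))  # int32 30000

# Request the highest available precision for float32 inputs;
# on some backends (e.g. TPU) this may avoid reduced-precision passes.
u = jnp.linspace(0.0, 1.0, 8, dtype=jnp.float32)
v = jnp.linspace(1.0, 2.0, 8, dtype=jnp.float32)
precise = jnp.inner(u, v, precision=jax.lax.Precision.HIGHEST)
print(precise.dtype)  # float32
```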

Returns:

array of shape `(*a.shape[:-1], *b.shape[:-1])` containing the inner product of every length-`N` vector in `a` with every length-`N` vector in `b`.

Return type:

Array
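To make the shape rule concrete, here is a sketch with arbitrarily chosen shapes (assumptions for illustration only). It checks that the result has shape `(*a.shape[:-1], *b.shape[:-1])` and that it matches `jnp.tensordot` contracting axis `-1` of both inputs, which is one way to write the same contraction.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3, 4))
b = jnp.ones((5, 4))

out = jnp.inner(a, b)

# Leading dims of a, followed by leading dims of b; the shared last dim (4) is summed.
expected_shape = (*a.shape[:-1], *b.shape[:-1])
print(out.shape, expected_shape)  # (2, 3, 5) (2, 3, 5)

# The same contraction written with tensordot over the last axis of each input.
same = jnp.tensordot(a, b, axes=(-1, -1))
print(bool(jnp.allclose(out, same)))  # True

# Every entry is a sum of four products 1 * 1.
print(float(out[0, 0, 0]))  # 4.0
```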

Examples

For 1D inputs, this computes the standard (non-conjugate) dot product:

```python
>>> import jax.numpy as jnp
>>> a = jnp.array([1j, 3j, 4j])
>>> b = jnp.array([4., 2., 5.])
>>> jnp.inner(a, b)
Array(0.+30.j, dtype=complex64)
```
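To see what "non-conjugate" means in practice, the sketch below reuses the inputs from the example above and compares against `jnp.vdot`, which, like `numpy.vdot`, conjugates its first argument before multiplying. The two results differ in the sign of the imaginary part.

```python
import jax.numpy as jnp

a = jnp.array([1j, 3j, 4j])
b = jnp.array([4., 2., 5.])

# inner: sum(a * b) with no conjugation; equals 30j.
plain = jnp.inner(a, b)

# vdot: sum(conj(a) * b); conj flips the sign of each imaginary entry,
# so this equals -30j.
conjugated = jnp.vdot(a, b)

# Explicit elementwise formulas for both, for comparison.
print(plain, jnp.sum(a * b))
print(conjugated, jnp.sum(jnp.conj(a) * b))
```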

For multi-dimensional inputs, the leading (batch) dimensions of `a` and `b` are concatenated in the output rather than broadcast against each other:

```python
>>> a = jnp.ones((2, 3))
>>> b = jnp.ones((5, 3))
>>> jnp.inner(a, b).shape
(2, 5)
```
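The sketch below contrasts this with a broadcasting, row-wise dot product. With two `(2, 3)` inputs (an illustrative assumption), `jnp.inner` pairs every row of `a` with every row of `b` and returns a `(2, 2)` array, whereas summing the elementwise product over the last axis pairs only matching rows and returns shape `(2,)`. The row-wise result is the diagonal of the stacked one.

```python
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(2, 3)        # rows [0, 1, 2] and [3, 4, 5]
b = jnp.arange(6.0, 12.0).reshape(2, 3)  # rows [6, 7, 8] and [9, 10, 11]

# Stacked: every row of a against every row of b.
stacked = jnp.inner(a, b)
print(stacked.shape)  # (2, 2)

# Broadcast: row i of a against row i of b only.
rowwise = jnp.sum(a * b, axis=-1)
print(rowwise.shape)  # (2,)

# The row-wise result is the diagonal of the stacked result.
print(bool(jnp.allclose(rowwise, jnp.diagonal(stacked))))  # True
```

If the broadcasting behavior is what you want, the row-wise form (or an equivalent `jnp.einsum` expression) is the appropriate tool rather than `jnp.inner`.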