jax.numpy.inner(a, b, *, precision=None)

Inner product of two arrays.

LAX-backend implementation of inner(). In addition to the original NumPy arguments listed below, it also accepts precision for extra control over matrix-multiplication precision on supported devices. precision may be None, meaning the backend's default precision, or any jax.lax.Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST).
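A minimal sketch of passing precision, assuming a recent JAX install; the arrays are arbitrary illustrations. On CPU the two results normally agree; the setting matters mainly on accelerators, where the default may use reduced-precision multiplication for float32 inputs.

```python
import jax.numpy as jnp
from jax import lax

# Small float32 inputs whose last dimensions match (size 3).
a = jnp.arange(6.0).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
b = jnp.ones((4, 3))

# Backend-default precision versus the most precise setting.
default = jnp.inner(a, b)
highest = jnp.inner(a, b, precision=lax.Precision.HIGHEST)

print(highest.shape)  # (2, 4)
# Every entry in row i equals sum(a[i]): 3.0 for i=0, 12.0 for i=1.
print(highest)
```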

Original docstring below.

inner(a, b)

Ordinary inner product of vectors for 1-D arrays (without complex conjugation); in higher dimensions, a sum product over the last axes.
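To illustrate "without complex conjugation", the sketch below (using NumPy, as the docstring's own examples do) contrasts inner with np.vdot, which conjugates its first argument. The input vectors are arbitrary.

```python
import numpy as np

u = np.array([1j, 2.0])
v = np.array([1j, 3.0])

# inner multiplies elementwise and sums, with no conjugation:
# 1j*1j + 2*3 = -1 + 6
plain = np.inner(u, v)  # 5+0j
# vdot conjugates its first argument: conj(1j)*1j + 2*3 = 1 + 6
conj = np.vdot(u, v)    # 7+0j
print(plain, conj)
```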


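Parameters: a, b (array_like). If a and b are nonscalar, their last dimensions must match.

Returns: out (ndarray). If a and b are both scalars or both 1-D arrays, a scalar is returned; otherwise an array is returned, with

out.shape = a.shape[:-1] + b.shape[:-1]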
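A quick check of the shape rule with NumPy; the shapes are chosen arbitrarily, with only the last dimension shared.

```python
import numpy as np

a = np.ones((2, 3, 4))
b = np.ones((5, 4))  # last dimension matches a's (4)

out = np.inner(a, b)
# Leading axes of a, then leading axes of b: (2, 3) + (5,)
print(out.shape)      # (2, 3, 5)
# Every entry sums four products of ones.
print(out[0, 0, 0])   # 4.0
```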


Raises ValueError if both a and b are nonscalar and their last dimensions have different sizes.
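A small sketch of the error case with NumPy: the last dimensions (3 and 4) differ, so the call fails. The exact message text varies between versions, so it is printed rather than asserted.

```python
import numpy as np

a = np.ones(3)
b = np.ones(4)  # last dimension 4 differs from a's 3

try:
    np.inner(a, b)
    raised = False
except ValueError as err:
    raised = True
    print("ValueError:", err)
```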

See also:
tensordot : Sum products over arbitrary axes.
dot : Generalised matrix product, using second last dimension of b.
einsum : Einstein summation convention.
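For 2-D inputs the three related functions can reproduce inner directly. This NumPy sketch, with arbitrary random inputs, assumes nothing beyond the documented behavior: dot contracts over the second-to-last axis of b, so b is transposed first; einsum names the shared last index explicitly.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3))
b = rng.standard_normal((4, 3))

via_inner = np.inner(a, b)
# dot pairs a's last axis with b's second-to-last axis,
# so transposing b lines its last axis up with a's.
via_dot = np.dot(a, b.T)
# einsum spells out the same contraction: sum over the shared index k.
via_einsum = np.einsum("ik,jk->ij", a, b)

print(np.allclose(via_inner, via_dot), np.allclose(via_inner, via_einsum))
```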

For vectors (1-D arrays) it computes the ordinary inner product:

np.inner(a, b) = sum(a[:]*b[:])

More generally, if ndim(a) = r > 0 and ndim(b) = s > 0:

np.inner(a, b) = np.tensordot(a, b, axes=(-1,-1))
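A numerical check of the tensordot identity for r = 3 and s = 2, using NumPy with arbitrary random inputs that share only their last dimension.

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.standard_normal((2, 3, 5))  # r = 3
b = rng.standard_normal((4, 5))     # s = 2, same last dimension

lhs = np.inner(a, b)
rhs = np.tensordot(a, b, axes=(-1, -1))  # contract the last axis of each
print(lhs.shape, np.allclose(lhs, rhs))
```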

or explicitly:

np.inner(a, b)[i0,...,ir-1,j0,...,js-1]
     = sum(a[i0,...,ir-1,:]*b[j0,...,js-1,:])
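The explicit index formula can be evaluated literally with Python loops. inner_explicit is a hypothetical helper written only for illustration (slow, and assuming r > 0 and s > 0); it is compared against np.inner.

```python
import itertools
import numpy as np

def inner_explicit(a, b):
    """Hypothetical reference helper: evaluates the explicit index formula
    with Python loops. Requires ndim(a) > 0 and ndim(b) > 0."""
    a = np.asarray(a)
    b = np.asarray(b)
    out = np.zeros(a.shape[:-1] + b.shape[:-1], dtype=np.result_type(a, b))
    for i in itertools.product(*(range(n) for n in a.shape[:-1])):
        for j in itertools.product(*(range(n) for n in b.shape[:-1])):
            # a[i] is the 1-D slice a[i0,...,ir-1,:]; likewise b[j].
            out[i + j] = np.sum(a[i] * b[j])
    return out

a = np.arange(24).reshape((2, 3, 4))
b = np.arange(8).reshape((2, 4))
print(np.array_equal(inner_explicit(a, b), np.inner(a, b)))  # True
```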

In addition, a or b may be scalars, in which case:

np.inner(a,b) = a*b
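A short NumPy sketch of the scalar case, with arbitrary values: when either operand is a scalar, the result is the ordinary broadcast product.

```python
import numpy as np

x = np.array([[1.0, 2.0], [3.0, 4.0]])

# With a scalar operand, inner reduces to elementwise multiplication.
print(np.array_equal(np.inner(x, 2.5), x * 2.5))  # True
print(np.inner(3, 4))                              # 12
```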

Ordinary inner product for vectors:

>>> a = np.array([1,2,3])
>>> b = np.array([0,1,0])
>>> np.inner(a, b)
2

A multidimensional example:

>>> a = np.arange(24).reshape((2,3,4))
>>> b = np.arange(4)
>>> np.inner(a, b)
array([[ 14,  38,  62],
       [ 86, 110, 134]])

An example where b is a scalar:

>>> np.inner(np.eye(2), 7)
array([[7., 0.],
       [0., 7.]])
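The examples above use NumPy. A minimal sketch, assuming a JAX install, that runs the multidimensional example through jax.numpy.inner and compares values with NumPy's result. Only values are compared, because JAX uses 32-bit integers by default unless 64-bit mode is enabled.

```python
import numpy as np
import jax.numpy as jnp

a = np.arange(24).reshape((2, 3, 4))
b = np.arange(4)

# jnp.inner accepts NumPy inputs and returns a jax.Array.
expected = np.inner(a, b)
result = jnp.inner(a, b)

print(np.array_equal(np.asarray(result), expected))  # True
```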