jax.numpy.inner(a, b, *, precision=None)

Inner product of two arrays.

LAX-backend implementation of inner(). In addition to the original NumPy arguments listed below, it also supports precision, which gives extra control over matrix-multiplication precision on supported devices. precision may be None (the backend's default precision), a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two lax.Precision enums giving a separate precision for each argument.
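
A minimal sketch of the three accepted forms of precision, assuming a working JAX installation. The array values are arbitrary illustrations; on CPU the three results should agree, while on accelerators with reduced-precision matmul units the DEFAULT and HIGHEST results may differ slightly.

```python
import jax.numpy as jnp
from jax import lax

# Arbitrary float32 inputs whose last dimensions (4) match.
a = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
b = jnp.arange(8, dtype=jnp.float32).reshape(2, 4)

# precision=None: use the backend's default matmul precision.
default = jnp.inner(a, b)

# A single enum applies to both operands.
highest = jnp.inner(a, b, precision=lax.Precision.HIGHEST)

# A tuple of two enums sets the precision for a and b separately.
mixed = jnp.inner(a, b, precision=(lax.Precision.HIGHEST, lax.Precision.DEFAULT))

print(highest.shape)  # (3, 2), i.e. a.shape[:-1] + b.shape[:-1]
```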

Original docstring below.

inner(a, b)

Ordinary inner product of vectors for 1-D arrays (without complex conjugation), in higher dimensions a sum product over the last axes.
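
"Without complex conjugation" matters for complex inputs. The sketch below contrasts inner with NumPy's vdot, a separate function that conjugates its first argument; the input values are arbitrary.

```python
import numpy as np

a = np.array([1j, 2])
b = np.array([1j, 3])

# inner: sum(a * b) with no conjugation -> (1j)(1j) + 2*3 = -1 + 6 = 5
plain = np.inner(a, b)

# vdot: conjugates a first -> (-1j)(1j) + 2*3 = 1 + 6 = 7
conjugated = np.vdot(a, b)
```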

Parameters

a, b (array_like) – If a and b are nonscalar, their last dimensions must match.

Returns

out – out.shape = a.shape[:-1] + b.shape[:-1]

Return type

ndarray

Raises

ValueError – If the last dimensions of a and b have different sizes.
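
The shape rule and the error condition can be checked directly. This sketch uses NumPy, whose docstring this section reproduces; it does not claim that jax.numpy.inner raises the same exception type for mismatched shapes. The arrays are arbitrary.

```python
import numpy as np

a = np.ones((2, 3, 4))
b = np.ones((5, 4))

# The last axes (size 4) are contracted; the remaining axes are concatenated.
out = np.inner(a, b)
print(out.shape)  # (2, 3, 5) == a.shape[:-1] + b.shape[:-1]

# Mismatched last dimensions (4 vs 3) are rejected.
raised = False
try:
    np.inner(a, np.ones(3))
except ValueError:
    raised = True
```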

See also

tensordot – Sum products over arbitrary axes.

dot – Generalised matrix product, using the second-to-last dimension of b.

einsum – Einstein summation convention.
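
For 2-D inputs the relationship to the functions above is easy to see: inner contracts the last axis of both arguments, so it matches dot with b transposed, or an einsum that sums over the shared last index. A small sketch with arbitrary random matrices:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 4))
B = rng.standard_normal((5, 4))

via_inner = np.inner(A, B)

# dot contracts A's last axis with the second-to-last axis of B.T.
via_dot = np.dot(A, B.T)

# Sum over the shared last index k; keep i from A and j from B.
via_einsum = np.einsum("ik,jk->ij", A, B)
```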

Notes

For vectors (1-D arrays) it computes the ordinary inner product:

np.inner(a, b) = sum(a[:]*b[:])

More generally, if ndim(a) = r > 0 and ndim(b) = s > 0:

np.inner(a, b) = np.tensordot(a, b, axes=(-1,-1))
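
A quick numerical check of the tensordot identity for r = s = 3, using arbitrary random arrays whose last dimensions match:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 3, 4))  # r = 3
b = rng.standard_normal((5, 6, 4))  # s = 3

lhs = np.inner(a, b)
rhs = np.tensordot(a, b, axes=(-1, -1))  # contract the last axis of each

print(lhs.shape)  # (2, 3, 5, 6)
```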

or explicitly:

np.inner(a, b)[i0,…,ir-2,j0,…,js-2] = sum(a[i0,…,ir-2,:]*b[j0,…,js-2,:])
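
The index formula can be transcribed into explicit loops. The helper inner_loops below is hypothetical and only for illustration (it is far slower than np.inner); it iterates over every leading index i of a and j of b, and sums the elementwise product of the two last-axis vectors.

```python
import numpy as np


def inner_loops(a, b):
    """Hypothetical loop-based transcription of the explicit index formula."""
    a = np.asarray(a)
    b = np.asarray(b)
    out = np.empty(a.shape[:-1] + b.shape[:-1], dtype=np.result_type(a, b))
    # i runs over a's leading indices (i0, ..., ir-2), j over b's (j0, ..., js-2).
    for i in np.ndindex(*a.shape[:-1]):
        for j in np.ndindex(*b.shape[:-1]):
            # a[i] and b[j] are the 1-D vectors along the shared last axis.
            out[i + j] = np.sum(a[i] * b[j])
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4))
y = rng.standard_normal((5, 4))
result = inner_loops(x, y)  # shape (2, 3, 5)
```

For 1-D inputs both index tuples are empty, so the loop body runs once and fills a 0-d result.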

In addition, a or b may be scalars, in which case:

np.inner(a,b) = a*b

Examples

Ordinary inner product for vectors:

>>> a = np.array([1,2,3])
>>> b = np.array([0,1,0])
>>> np.inner(a, b)
2

A multidimensional example:

>>> a = np.arange(24).reshape((2,3,4))
>>> b = np.arange(4)
>>> np.inner(a, b)
array([[ 14,  38,  62],
       [ 86, 110, 134]])

An example where b is a scalar:

>>> np.inner(np.eye(2), 7)
array([[7., 0.],
       [0., 7.]])