jax.numpy.inner
jax.numpy.inner(a, b, *, precision=None)
Inner product of two arrays.
LAX-backend implementation of inner(). In addition to the original NumPy arguments listed below, also supports precision for extra control over matrix-multiplication precision on supported devices. precision may be set to None, which means default precision for the backend, a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two lax.Precision enums indicating separate precision for each argument.
Original docstring below.
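A minimal sketch of the precision argument, assuming a recent JAX install. The inputs are chosen so the exact result is representable in float32, so all three calls are expected to agree; precision only changes results for general inputs on hardware (such as TPUs) where it selects a different matrix-multiplication algorithm.

```python
import jax.numpy as jnp
from jax import lax

# Two small float32 arrays; their last dimensions (size 3) must match.
a = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]], dtype=jnp.float32)
b = jnp.array([0.5, -1.0, 2.0], dtype=jnp.float32)

# None: use the backend's default precision.
out_default = jnp.inner(a, b, precision=None)

# A single enum value applies to both operands.
out_highest = jnp.inner(a, b, precision=lax.Precision.HIGHEST)

# A tuple of two enums sets the precision for a and for b separately.
out_mixed = jnp.inner(a, b, precision=(lax.Precision.HIGH, lax.Precision.HIGHEST))

# Exact values: 0.5 - 2 + 6 = 4.5 and 2 - 5 + 12 = 9.
print(out_default, out_highest, out_mixed)
```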
inner(a, b)
Ordinary inner product of vectors for 1-D arrays (without complex conjugation), in higher dimensions a sum product over the last axes.
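To illustrate the "without complex conjugation" remark, this sketch compares inner() with jnp.vdot, which conjugates its first argument; the vectors are arbitrary illustrative values.

```python
import jax.numpy as jnp

# Complex vectors: inner() multiplies them as-is, with no conjugation.
a = jnp.array([1 + 2j, 3 - 1j])
b = jnp.array([2 - 1j, 1j])

plain = jnp.inner(a, b)       # sum(a * b), expected 5+6j
conjugated = jnp.vdot(a, b)   # sum(conj(a) * b), expected -1-2j
```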
Parameters
a, b : If a and b are nonscalar, their last dimensions must match.
Returns
out : out.shape = a.shape[:-1] + b.shape[:-1]
Raises
ValueError : If the last dimensions of a and b have different sizes.
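A short sketch of the output-shape rule, using arrays of ones so every entry is simply the length of the shared last axis (4 here).

```python
import jax.numpy as jnp

# The last dimensions must match (size 4); the remaining axes are kept in order.
a = jnp.ones((2, 3, 4))
b = jnp.ones((5, 4))

out = jnp.inner(a, b)

# out.shape == a.shape[:-1] + b.shape[:-1] == (2, 3) + (5,)
print(out.shape)  # (2, 3, 5)
```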
See also
tensordot()
Sum products over arbitrary axes.
dot()
Generalised matrix product, using second last dimension of b.
einsum()
Einstein summation convention.
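The difference from dot() is easiest to see for 2-D arrays: inner() contracts the last axis of both operands, so it matches dot() against the transpose of the second argument. A minimal sketch with illustrative shapes:

```python
import jax.numpy as jnp

A = jnp.arange(6.0).reshape(2, 3)   # shape (2, 3)
B = jnp.arange(12.0).reshape(4, 3)  # shape (4, 3); last dim matches A

# inner contracts the last axis of A with the last axis of B.
via_inner = jnp.inner(A, B)         # shape (2, 4)

# dot contracts A's last axis with B's second-to-last axis,
# so B must be transposed to obtain the same result.
via_dot = jnp.dot(A, B.T)           # shape (2, 4)
```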
Notes
For vectors (1-D arrays) it computes the ordinary inner-product:
np.inner(a, b) = sum(a[:]*b[:])
More generally, if ndim(a) = r > 0 and ndim(b) = s > 0:
np.inner(a, b) = np.tensordot(a, b, axes=(-1,-1))
or explicitly:
np.inner(a, b)[i0,...,ir-2,j0,...,js-2]
    = sum(a[i0,...,ir-2,:]*b[j0,...,js-2,:])
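The general case can be checked numerically. This sketch, with arbitrary illustrative shapes, compares inner() against the tensordot form, an equivalent einsum, and the explicit sum for one output entry.

```python
import jax.numpy as jnp

a = jnp.arange(24.0).reshape(2, 3, 4)  # ndim(a) = 3
b = jnp.arange(8.0).reshape(2, 4)      # ndim(b) = 2; last dim matches a

out = jnp.inner(a, b)                  # shape (2, 3, 2)

# Same contraction over the last axes with tensordot.
via_tensordot = jnp.tensordot(a, b, axes=(-1, -1))

# Same contraction with einsum: sum over the shared last index k.
via_einsum = jnp.einsum("ijk,lk->ijl", a, b)

# Explicit element formula for one output entry, out[1, 2, 0].
entry = jnp.sum(a[1, 2, :] * b[0, :])
```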
In addition, a or b may be scalars, in which case:
np.inner(a,b) = a*b
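A small sketch of the scalar case: when either operand is 0-d, inner() reduces to elementwise multiplication. The matrix values are illustrative.

```python
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# One scalar operand: the result equals a * 3.0 with a's shape.
scaled = jnp.inner(a, 3.0)
same = a * 3.0

# Both operands scalar: a 0-d result equal to their product.
both_scalar = jnp.inner(2.0, 5.0)
```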
Examples
Ordinary inner product for vectors:
>>> a = np.array([1,2,3])
>>> b = np.array([0,1,0])
>>> np.inner(a, b)
2
A multidimensional example:
>>> a = np.arange(24).reshape((2,3,4))
>>> b = np.arange(4)
>>> np.inner(a, b)
array([[ 14,  38,  62],
       [ 86, 110, 134]])
An example where b is a scalar:
>>> np.inner(np.eye(2), 7)
array([[7., 0.],
       [0., 7.]])