jax.numpy.inner
- jax.numpy.inner(a, b, *, precision=None, preferred_element_type=None)
Inner product of two arrays.
LAX-backend implementation of numpy.inner().
In addition to the original NumPy arguments listed below, also supports precision for extra control over matrix-multiplication precision on supported devices. precision may be set to None, which means default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST) or a tuple of two Precision enums indicating separate precision for each argument.
Original docstring below.
Ordinary inner product of vectors for 1-D arrays (without complex conjugation); in higher dimensions, a sum product over the last axes.
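As a rough illustration of the description above, the following sketch shows the 1-D case, the absence of complex conjugation, and the sum over the last axis for a 2-D input. The array values are made up for the example, and jnp.vdot is used only as a contrast because it conjugates its first argument.

```python
import jax.numpy as jnp

# 1-D inputs: an ordinary dot product, returned as a 0-d array.
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])
s = jnp.inner(a, b)  # 1*4 + 2*5 + 3*6

# Complex inputs are NOT conjugated by inner; vdot conjugates its first argument.
z = jnp.array([1j, 2j])
no_conj = jnp.inner(z, z)   # 1j*1j + 2j*2j
with_conj = jnp.vdot(z, z)  # conj(1j)*1j + conj(2j)*2j

# Higher dimensions: sum product over the last axis of each input.
m = jnp.arange(6.0).reshape(2, 3)  # rows [0, 1, 2] and [3, 4, 5]
mv = jnp.inner(m, a)               # one inner product per row of m
```

Here s is 32, no_conj is -5 while with_conj is 5, and mv has shape (2,) with values 8 and 26.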
- Parameters:
a (array_like) – If a and b are nonscalar, their last dimensions must match.
b (array_like) – If a and b are nonscalar, their last dimensions must match.
preferred_element_type (dtype, optional) – If specified, accumulate results and return a result of the given data type. If not specified, the accumulation dtype is determined from the type promotion rules of the input array dtypes.
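precision (PrecisionLike) – None (the backend's default precision), a Precision enum value, or a tuple of two Precision enums giving a separate precision for each argument.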
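The two JAX-specific keyword arguments might be used as in the sketch below. The shapes and dtypes are arbitrary choices for the example, and whether precision changes the numerical result depends on the backend; on some devices it may have no visible effect.

```python
import jax
import jax.numpy as jnp

x = jnp.ones((4, 8), dtype=jnp.float32)
y = jnp.ones((5, 8), dtype=jnp.float32)

# A single Precision value applies to both arguments.
highest = jnp.inner(x, y, precision=jax.lax.Precision.HIGHEST)

# A tuple of two Precision values sets the precision per argument.
mixed = jnp.inner(
    x, y, precision=(jax.lax.Precision.HIGH, jax.lax.Precision.DEFAULT)
)

# preferred_element_type: accumulate and return in a wider dtype than the inputs.
xb = x.astype(jnp.bfloat16)
yb = y.astype(jnp.bfloat16)
default = jnp.inner(xb, yb)                                    # follows promotion rules
acc32 = jnp.inner(xb, yb, preferred_element_type=jnp.float32)  # float32 result
```

With bfloat16 inputs, the default call returns bfloat16, while the call with preferred_element_type=jnp.float32 returns float32.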
- Returns:
out – If a and b are both scalars or both 1-D arrays then a scalar is returned; otherwise an array is returned.
out.shape = (*a.shape[:-1], *b.shape[:-1])
- Return type:
ndarray
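The shape rule for the result can be checked with a small sketch like the one below. The shapes are arbitrary, and the comparison with jnp.tensordot is included only to show the same contraction over the last axes written another way.

```python
import jax.numpy as jnp

a = jnp.ones((2, 3, 4))
b = jnp.ones((5, 4))  # last dimension matches a's last dimension

out = jnp.inner(a, b)
expected_shape = (*a.shape[:-1], *b.shape[:-1])  # (2, 3, 5)

# The same contraction over the last axis of each input, via tensordot.
same = jnp.tensordot(a, b, axes=(-1, -1))

# Two 1-D inputs give a 0-d (scalar) result.
scalar_out = jnp.inner(jnp.ones(4), jnp.ones(4))

# A scalar argument multiplies the other array elementwise.
scaled = jnp.inner(2.0, b)
```

Note that in JAX the "scalar" result for two 1-D inputs is a 0-d array with shape ().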