jax.numpy.tensordot
- jax.numpy.tensordot(a, b, axes=2, *, precision=None, preferred_element_type=None)
Compute tensor dot product along specified axes.
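As a quick illustration of the two forms of axes described in the docstring below, the following sketch contracts a 3-D array with a 2-D array (integer form) and with a 3-D array (pair form). It then cross-checks the results against equivalent jnp.einsum expressions and numpy.tensordot. It assumes a working JAX installation; the array shapes and values are arbitrary choices for illustration.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.arange(24.0).reshape(2, 3, 4)

# Integer form: axes=N sums over the last N axes of a and the first N axes of b.
b = jnp.ones((3, 4))
c = jnp.tensordot(a, b, axes=2)  # a's trailing (3, 4) block is contracted with b -> shape (2,)

# Pair form: axis 1 of a is paired with axis 1 of b2 (both size 3),
# and axis 2 of a with axis 0 of b2 (both size 4).
b2 = jnp.arange(60.0).reshape(4, 3, 5)
d = jnp.tensordot(a, b2, axes=([1, 2], [1, 0]))  # remaining axes of a, then of b2 -> (2, 5)

# Cross-check against equivalent einsum expressions and NumPy's tensordot.
np.testing.assert_allclose(c, jnp.einsum("ijk,jk->i", a, b))
np.testing.assert_allclose(d, jnp.einsum("ijk,kjl->il", a, b2))
np.testing.assert_allclose(
    d, np.tensordot(np.asarray(a), np.asarray(b2), axes=([1, 2], [1, 0]))
)
print(c.shape, d.shape)  # (2,) (2, 5)
```

The result keeps the uncontracted axes of a first, followed by the uncontracted axes of b, which is why d has shape (2, 5).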
LAX-backend implementation of numpy.tensordot().

In addition to the original NumPy arguments listed below, it also supports precision for extra control over matrix-multiplication precision on supported devices. The precision argument may be set to None, which means the default precision for the backend; a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST); or a tuple of two Precision enums indicating separate precision for each argument.

Original docstring below.
Given two tensors, a and b, and an array_like object containing two array_like objects, (a_axes, b_axes), sum the products of a’s and b’s elements (components) over the axes specified by a_axes and b_axes. The third argument can be a single non-negative integer_like scalar, N; if it is such, then the last N dimensions of a and the first N dimensions of b are summed over.

- Parameters:
a (array_like) – Tensors to “dot”.
b (array_like) – Tensors to “dot”.
axes (int or (2,) array_like) –
  - integer_like: If an int N, sum over the last N axes of a and the first N axes of b in order. The sizes of the corresponding axes must match.
  - (2,) array_like: Or, a list of axes to be summed over, the first sequence applying to a and the second to b. Both sequences must have the same length.
preferred_element_type (dtype, optional) – If specified, accumulate results and return a result of the given data type. If not specified, the accumulation dtype is determined from the type promotion rules of the input array dtypes.
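precision (PrecisionLike) – None (the backend's default precision), a single Precision enum value, or a tuple of two Precision values giving separate precision for each argument.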
- Returns:
output – The tensor dot product of the inputs.
- Return type:
ndarray
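The JAX-specific keywords can be exercised as in the following minimal sketch. It assumes a recent JAX version where the Precision enum is exposed as jax.lax.Precision. On backends that do not distinguish precision levels (for example, CPU), the different precision settings may all give identical results. The int8 to int32 case is one illustrative choice of preferred_element_type, not the only supported combination.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((4, 8), dtype=jnp.float32)
y = jnp.ones((8, 3), dtype=jnp.float32)

# precision=None (the default) uses the backend's default precision.
z_default = jnp.tensordot(x, y, axes=1)

# A single Precision enum value applies to both operands.
z_highest = jnp.tensordot(x, y, axes=1, precision=lax.Precision.HIGHEST)

# A tuple of two Precision values sets the precision of each operand separately.
z_pair = jnp.tensordot(
    x, y, axes=1, precision=(lax.Precision.HIGHEST, lax.Precision.DEFAULT)
)

# preferred_element_type sets the accumulation and result dtype:
# int8 operands are accumulated in int32, so 8 * 100 * 100 = 80000 fits.
xi = jnp.full((4, 8), 100, dtype=jnp.int8)
yi = jnp.full((8, 3), 100, dtype=jnp.int8)
zi = jnp.tensordot(xi, yi, axes=1, preferred_element_type=jnp.int32)

print(z_highest.shape, zi.dtype, int(zi[0, 0]))  # (4, 3) int32 80000
```

Without preferred_element_type, the result dtype would follow the type promotion of the int8 inputs, which is too narrow to hold 80000.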