# jax.numpy.tensordot

jax.numpy.tensordot(a, b, axes=2, *, precision=None, preferred_element_type=None)

Compute the tensor dot product along the specified axes.

LAX-backend implementation of `numpy.tensordot()`.
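The sketch below compares `jnp.tensordot` against `numpy.tensordot` on small arbitrary inputs. It also runs the function under `jax.jit`, which works because it is built from LAX operations. `axes` is kept as a Python constant inside the jitted function, on the assumption that it determines the output shape and so should not be a traced value. The array shapes and the helper name `contract_last_two` are illustrative only.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Arbitrary small inputs (illustrative values only).
a_np = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
b_np = np.arange(24, dtype=np.float32).reshape(3, 4, 2)

# Reference result computed by NumPy.
expected = np.tensordot(a_np, b_np, axes=2)

# The same contraction through jax.numpy.
result = jnp.tensordot(jnp.asarray(a_np), jnp.asarray(b_np), axes=2)

# Hypothetical helper: `axes` is fixed as a Python constant so that only the
# array arguments are traced by jax.jit.
@jax.jit
def contract_last_two(x, y):
    return jnp.tensordot(x, y, axes=2)

jitted = contract_last_two(a_np, b_np)

print(result.shape)  # (2, 2)
print(np.allclose(result, expected), np.allclose(jitted, expected))
```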

In addition to the original NumPy arguments listed below, this function also supports `precision` for extra control over matrix-multiplication precision on supported devices. `precision` may be `None`, meaning the backend's default precision; a single `Precision` enum value (`Precision.DEFAULT`, `Precision.HIGH`, or `Precision.HIGHEST`); or a tuple of two `Precision` values giving a separate precision for each argument.
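The three accepted forms of `precision` can be sketched as follows, using `jax.lax.Precision` for the enum values. The operand shapes and the random seed are arbitrary. On CPU all three settings will typically produce the same numbers. Differences are more likely on accelerators that can use reduced-precision matrix-multiply passes, such as TPUs or GPUs. The float64 NumPy reference is included only to make any such difference visible.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

# Random float32 operands; shapes are arbitrary.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (8, 16), dtype=jnp.float32)
b = jax.random.normal(key_b, (16, 4), dtype=jnp.float32)

# None: use the backend's default precision.
out_default = jnp.tensordot(a, b, axes=1, precision=None)

# A single Precision value applies to both operands.
out_highest = jnp.tensordot(a, b, axes=1, precision=lax.Precision.HIGHEST)

# A tuple gives one Precision per operand: (precision for a, precision for b).
out_mixed = jnp.tensordot(
    a, b, axes=1, precision=(lax.Precision.HIGHEST, lax.Precision.DEFAULT)
)

# float64 NumPy reference for judging the error of each setting.
reference = np.tensordot(np.asarray(a, np.float64), np.asarray(b, np.float64), axes=1)
for name, out in [("default", out_default), ("highest", out_highest), ("mixed", out_mixed)]:
    print(name, np.max(np.abs(np.asarray(out) - reference)))
```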

Original docstring below.

Given two tensors, `a` and `b`, and an array_like object containing two array_like objects, `(a_axes, b_axes)`, sum the products of the elements (components) of `a` and `b` over the axes specified by `a_axes` and `b_axes`. The third argument can instead be a single non-negative integer_like scalar, `N`; in that case, the last `N` dimensions of `a` and the first `N` dimensions of `b` are summed over.
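For the integer form of `axes`, the sketch below checks that `tensordot` with `N = 2` matches two equivalent formulations: an explicit index sum written with `jnp.einsum`, and a single matrix product after flattening the contracted axes. It also shows that `N = 0` produces an outer product. The shapes are arbitrary and chosen only so that the contracted sizes match.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.arange(60.0).reshape(3, 4, 5)
b = jnp.arange(120.0).reshape(4, 5, 6)

# axes=2: sum over the last two axes of `a` (sizes 4, 5)
# and the first two axes of `b` (sizes 4, 5).
c = jnp.tensordot(a, b, axes=2)

# The same sum written out index by index: c[i, l] = sum_{j,k} a[i, j, k] * b[j, k, l].
c_einsum = jnp.einsum("ijk,jkl->il", a, b)

# Equivalently, flatten the contracted axes and do one matrix product.
c_matmul = a.reshape(3, 4 * 5) @ b.reshape(4 * 5, 6)

# axes=0 sums over nothing, which yields the outer (tensor) product.
outer = jnp.tensordot(a[0, 0], b[0, 0], axes=0)

print(c.shape, outer.shape)  # (3, 6) (5, 6)
```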

Parameters:
• a (array_like) – Tensors to “dot”.

• b (array_like) – Tensors to “dot”.

• axes (int or (2,) array_like) – The axes to sum over, in one of two forms:

  • integer_like: If an int `N`, sum over the last `N` axes of `a` and the first `N` axes of `b`, in order. The sizes of the corresponding axes must match.

  • (2,) array_like: Alternatively, a pair of axis sequences to sum over, the first applying to `a` and the second to `b`. Both sequences must have the same length.

• preferred_element_type (dtype, optional) – If specified, accumulate results and return a result of the given data type. If not specified, the accumulation dtype is determined from the type promotion rules of the input array dtypes.

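• precision (PrecisionLike) – `None`, a `Precision` enum value, or a tuple of two `Precision` values; controls matrix-multiplication precision on supported devices, as described above.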
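The pair form of `axes` matches axes position by position: the i-th entry of the first sequence is contracted with the i-th entry of the second. The sketch below uses arbitrary shapes. It shows that listing the same pairs in a different order leaves the result unchanged, and it cross-checks against an `einsum` expression. The output keeps the uncontracted axes of `a` followed by those of `b`.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.arange(24.0).reshape(2, 3, 4)
b = jnp.arange(60.0).reshape(4, 3, 5)

# Pair form: position i of the first sequence is matched with position i of
# the second, so axis 1 of `a` (size 3) pairs with axis 1 of `b` (size 3)
# and axis 2 of `a` (size 4) pairs with axis 0 of `b` (size 4).
c = jnp.tensordot(a, b, axes=([1, 2], [1, 0]))

# Listing the same pairs in a different order gives the same result.
c_reordered = jnp.tensordot(a, b, axes=([2, 1], [0, 1]))

# Explicit index form: c[i, l] = sum_{j,k} a[i, j, k] * b[k, j, l].
c_einsum = jnp.einsum("ijk,kjl->il", a, b)

# The remaining axes of `a` come first, followed by those of `b`.
print(c.shape)  # (2, 5)
```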

Returns:

output – The tensor dot product of the inputs.

Return type:

ndarray
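The effect of `preferred_element_type` on the returned dtype can be sketched with float16 inputs, chosen only for illustration. Without the argument, the result dtype follows type promotion and stays float16. With `preferred_element_type=jnp.float32`, the result is returned as float32. All-ones inputs keep the expected values exact: each output entry is a sum of eight ones.

```python
import jax.numpy as jnp

a = jnp.ones((4, 8), dtype=jnp.float16)
b = jnp.ones((8, 3), dtype=jnp.float16)

# Without preferred_element_type, the result dtype follows type promotion.
out_default = jnp.tensordot(a, b, axes=1)

# Request float32 accumulation and a float32 result.
out_f32 = jnp.tensordot(a, b, axes=1, preferred_element_type=jnp.float32)

print(out_default.dtype, out_f32.dtype)  # float16 float32
```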