jax.numpy.tensordot

jax.numpy.tensordot(a, b, axes=2, *, precision=None)

Compute tensor dot product along specified axes.

LAX-backend implementation of numpy.tensordot().

In addition to the original NumPy arguments listed below, this function also accepts precision, which gives extra control over matrix-multiplication precision on supported devices. precision may be None (the backend's default precision), a lax.Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST), or a tuple of two lax.Precision enums that sets a separate precision for each argument.
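A minimal sketch of the three accepted forms of precision, assuming a standard JAX installation. The arrays are small illustrative inputs; whether the setting changes the numerical result at all depends on the backend, so the sketch only checks that each form runs and agrees with a NumPy reference.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

a = jnp.arange(12.0).reshape(3, 4)
b = jnp.arange(20.0).reshape(4, 5)

# precision=None (the default): the backend's default precision.
c_default = jnp.tensordot(a, b, axes=1)

# A single lax.Precision value applies to both operands.
c_highest = jnp.tensordot(a, b, axes=1, precision=lax.Precision.HIGHEST)

# A tuple of two values sets the precision separately for a and for b.
c_mixed = jnp.tensordot(
    a, b, axes=1, precision=(lax.Precision.HIGHEST, lax.Precision.DEFAULT)
)

# NumPy reference result for comparison.
expected = np.tensordot(np.asarray(a), np.asarray(b), axes=1)
```

Because these inputs are small integers stored as floats, they are exactly representable even in reduced-precision formats, so all three results should match the reference.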

Original docstring below.

Given two tensors, a and b, and an array_like object containing two array_like objects, (a_axes, b_axes), sum the products of a’s and b’s elements (components) over the axes specified by a_axes and b_axes. The third argument can instead be a single non-negative integer_like scalar, N; in that case, the last N dimensions of a and the first N dimensions of b are summed over.
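To make the integer form concrete, here is a sketch (the shapes are arbitrary illustrative choices) that contracts the last two axes of a with the first two axes of b using axes=2 and checks the result against the equivalent jnp.einsum expression and numpy.tensordot.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.arange(24.0).reshape(2, 3, 4)
b = jnp.arange(60.0).reshape(3, 4, 5)

# axes=2: contract the last two axes of a (sizes 3, 4)
# with the first two axes of b (sizes 3, 4).
c = jnp.tensordot(a, b, axes=2)

# The same contraction written explicitly with einsum:
# c[i, l] = sum over j, k of a[i, j, k] * b[j, k, l].
c_einsum = jnp.einsum("ijk,jkl->il", a, b)

# NumPy reference for comparison.
c_numpy = np.tensordot(np.asarray(a), np.asarray(b), axes=2)

print(c.shape)  # (2, 5)
print(np.allclose(c, c_einsum), np.allclose(c, c_numpy))  # True True
```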

Parameters
  • a (array_like) – Tensors to “dot”.

  • b (array_like) – Tensors to “dot”.

  • axes (int or (2,) array_like) –

    • integer_like: If an int N, sum over the last N axes of a and the first N axes of b, in order. The sizes of the corresponding axes must match.

    • (2,) array_like: Alternatively, a pair of axis sequences to sum over, the first applying to a and the second to b. Both sequences must have the same length.
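The sketch below illustrates both forms of axes with arbitrary example shapes. With a pair of sequences, axes are matched position by position, so they need not be adjacent or in the same order in a and b. The integer cases axes=0 and axes=1 reduce to the outer product and the ordinary matrix product for 1-D and 2-D inputs.

```python
import numpy as np
import jax.numpy as jnp

a = jnp.arange(60.0).reshape(3, 4, 5)
b = jnp.arange(24.0).reshape(4, 3, 2)

# Pair of sequences: axis 0 of a (size 3) is paired with axis 1 of b,
# and axis 1 of a (size 4) with axis 0 of b.
c = jnp.tensordot(a, b, axes=([0, 1], [1, 0]))
print(c.shape)  # (5, 2): remaining axis of a, then remaining axis of b

# Same contraction in einsum notation and in NumPy.
c_einsum = jnp.einsum("ijk,jil->kl", a, b)
c_numpy = np.tensordot(np.asarray(a), np.asarray(b), axes=([0, 1], [1, 0]))

# Integer special cases.
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0])
outer = jnp.tensordot(x, y, axes=0)  # no axes summed: outer product, shape (3, 2)

m = jnp.arange(6.0).reshape(2, 3)
n = jnp.arange(12.0).reshape(3, 4)
matmul = jnp.tensordot(m, n, axes=1)  # one axis summed: matrix product, shape (2, 4)
```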

Returns

output – The tensor dot product of the inputs.

Return type

ndarray
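The shape of the returned array follows a simple rule: the non-contracted axes of a, in their original order, followed by the non-contracted axes of b. The sketch below encodes that rule in a hypothetical helper, tensordot_result_shape (not part of JAX), and compares its prediction with the shape jnp.tensordot actually returns.

```python
import jax.numpy as jnp


def tensordot_result_shape(a_shape, b_shape, a_axes, b_axes):
    """Hypothetical helper (not a JAX API): predict tensordot's output shape.

    The result keeps the non-contracted axes of a, in order,
    followed by the non-contracted axes of b, in order.
    """
    # Normalize negative axis indices.
    a_axes = [ax % len(a_shape) for ax in a_axes]
    b_axes = [ax % len(b_shape) for ax in b_axes]
    # Paired axes must have matching sizes.
    for i, j in zip(a_axes, b_axes):
        if a_shape[i] != b_shape[j]:
            raise ValueError(f"size mismatch: a axis {i} vs b axis {j}")
    free_a = [d for k, d in enumerate(a_shape) if k not in a_axes]
    free_b = [d for k, d in enumerate(b_shape) if k not in b_axes]
    return tuple(free_a + free_b)


a = jnp.ones((2, 3, 4, 5))
b = jnp.ones((5, 4, 6))
axes = ([3, 2], [0, 1])

predicted = tensordot_result_shape(a.shape, b.shape, *axes)
actual = jnp.tensordot(a, b, axes=axes).shape
print(predicted, actual)  # (2, 3, 6) (2, 3, 6)
```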