jax.numpy.tensordot


jax.numpy.tensordot(a, b, axes=2, *, precision=None, preferred_element_type=None)

Compute tensor dot product along specified axes.

LAX-backend implementation of numpy.tensordot().

In addition to the original NumPy arguments listed below, this function also supports precision, which gives extra control over matrix-multiplication precision on supported devices. precision may be None (the default precision for the backend), a Precision enum value (Precision.DEFAULT, Precision.HIGH, or Precision.HIGHEST), or a tuple of two Precision enums giving a separate precision for each argument.
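The accepted forms of precision can be sketched as follows. This is a minimal illustration, assuming the Precision enum is accessed as jax.lax.Precision; the array shapes and variable names are made up for the example, and any numerical difference between settings depends on the device.

```python
import jax
import jax.numpy as jnp

# Illustrative inputs: a (4, 8) matrix and an (8, 3) matrix of ones.
a = jnp.ones((4, 8), dtype=jnp.float32)
b = jnp.ones((8, 3), dtype=jnp.float32)

# precision=None (the default): use the backend's default precision.
default = jnp.tensordot(a, b, axes=1)

# A single Precision enum value applies to both arguments.
highest = jnp.tensordot(a, b, axes=1, precision=jax.lax.Precision.HIGHEST)

# A tuple of two Precision enums sets the precision per argument.
mixed = jnp.tensordot(
    a,
    b,
    axes=1,
    precision=(jax.lax.Precision.HIGHEST, jax.lax.Precision.DEFAULT),
)

print(default.shape, highest.shape, mixed.shape)
```

The precision setting only changes how the multiplication is carried out internally; the output shape and dtype are the same for all three calls.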

Original docstring below.

Given two tensors, a and b, and an array_like object containing two array_like objects, (a_axes, b_axes), sum the products of a’s and b’s elements (components) over the axes specified by a_axes and b_axes. The third argument can instead be a single non-negative integer_like scalar, N, in which case the last N dimensions of a and the first N dimensions of b are summed over.
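The two ways of specifying the summed axes can be compared directly. The sketch below uses illustrative shapes (not from the original docstring) and checks that the integer form, the explicit (a_axes, b_axes) form, and an equivalent jnp.einsum contraction agree.

```python
import jax.numpy as jnp

# Illustrative tensors: a has shape (2, 3, 4), b has shape (3, 4, 5).
a = jnp.arange(24.0).reshape(2, 3, 4)
b = jnp.arange(60.0).reshape(3, 4, 5)

# Integer form: sum over the last 2 axes of a and the first 2 axes of b.
c_int = jnp.tensordot(a, b, axes=2)

# Explicit form: axes (1, 2) of a are paired with axes (0, 1) of b.
c_pairs = jnp.tensordot(a, b, axes=((1, 2), (0, 1)))

# The same contraction written with einsum, for comparison.
c_einsum = jnp.einsum("ijk,jkl->il", a, b)

print(c_int.shape)
print(bool(jnp.allclose(c_int, c_pairs)))
print(bool(jnp.allclose(c_int, c_einsum)))
```

Here the axes of size 3 and 4 are summed away, leaving the remaining axis of a (size 2) and the remaining axis of b (size 5).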

Parameters:
  • a (array_like) – First tensor to “dot”.

  • b (array_like) – Second tensor to “dot”.

  • axes (int or (2,) array_like) –

    • integer_like: If an int N, sum over the last N axes of a and the first N axes of b, in order. The sizes of the corresponding axes must match.

    • (2,) array_like: Or, a list of axes to be summed over, the first sequence applying to a and the second to b. Both sequences must be of the same length.

  • preferred_element_type (dtype, optional) – If specified, accumulate results and return a result of the given data type. If not specified, the accumulation dtype is determined from the type promotion rules of the input array dtypes.

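  • precision (PrecisionLike) – Precision of the matrix multiplication on supported devices: None (the default precision for the backend), a Precision enum value, or a tuple of two Precision enums, one for each argument.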
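The effect of preferred_element_type on the result dtype can be sketched as follows. The float16 inputs and variable names are chosen purely for illustration.

```python
import jax.numpy as jnp

# Illustrative half-precision inputs.
a = jnp.ones((2, 4), dtype=jnp.float16)
b = jnp.ones((4, 3), dtype=jnp.float16)

# Without preferred_element_type, the result dtype follows the usual
# type promotion of the inputs (float16 here).
out_default = jnp.tensordot(a, b, axes=1)

# With preferred_element_type, results are accumulated and returned
# in the requested dtype.
out_f32 = jnp.tensordot(a, b, axes=1, preferred_element_type=jnp.float32)

print(out_default.dtype)
print(out_f32.dtype)
```

Requesting a wider type in this way can be useful when low-precision inputs are summed over long axes, since the accumulation then happens in the wider type.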

Returns:

output – The tensor dot product of the input.

Return type:

ndarray
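As an illustration of the returned array, the sketch below inspects output shapes for two cases, following the NumPy convention that the non-summed axes of a come first, followed by the non-summed axes of b. The shapes are illustrative, and the axes=0 case (no axes summed, giving an outer product) is included as an assumption carried over from NumPy's behavior.

```python
import jax.numpy as jnp

# Illustrative tensors: a has shape (3, 4, 5), b has shape (4, 3, 2).
a = jnp.arange(60.0).reshape(3, 4, 5)
b = jnp.arange(24.0).reshape(4, 3, 2)

# Pair axis 1 of a with axis 0 of b (size 4), and axis 0 of a with
# axis 1 of b (size 3). The remaining axes give the output shape.
c = jnp.tensordot(a, b, axes=([1, 0], [0, 1]))

# The same contraction written with einsum, for comparison.
c_einsum = jnp.einsum("ijk,jil->kl", a, b)

# axes=0 sums over no axes, so the output keeps every axis of a and b.
outer = jnp.tensordot(a, b, axes=0)

print(c.shape)
print(outer.shape)
```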