jax.lax.dot_general

jax.lax.dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None)

General dot product/contraction operator.

Wraps XLA’s DotGeneral operator.

The semantics of dot_general are complicated, but most users should not have to use it directly. Instead, you can use higher-level functions like jax.numpy.dot(), jax.numpy.matmul(), jax.numpy.tensordot(), jax.numpy.einsum(), and others which will construct appropriate calls to dot_general under the hood. If you really want to understand dot_general itself, we recommend reading XLA’s DotGeneral operator documentation.
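As a rough illustration of how a higher-level function maps onto dot_general, the sketch below expresses a plain 2-D matrix product directly and compares it with jax.numpy.matmul() and jax.numpy.einsum(). It assumes a recent JAX installation; the array names x and y are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

# Two small matrices: x has shape (2, 3), y has shape (3, 4).
x = jnp.arange(6.0).reshape(2, 3)
y = jnp.arange(12.0).reshape(3, 4)

# An ordinary matrix product contracts axis 1 of x with axis 0 of y
# and has no batch dimensions.
dn = (((1,), (0,)), ((), ()))
z = lax.dot_general(x, y, dimension_numbers=dn)

# The higher-level wrappers compute the same result.
z_matmul = jnp.matmul(x, y)
z_einsum = jnp.einsum("ij,jk->ik", x, y)

print(z.shape)  # (2, 4)
print(z[0, 0])  # 20.0, i.e. 0*0 + 1*4 + 2*8
print(jnp.allclose(z, z_matmul), jnp.allclose(z, z_einsum))
```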

Parameters:
  • lhs (ArrayLike) – an array

  • rhs (ArrayLike) – an array

  • dimension_numbers (DotDimensionNumbers) – a tuple of tuples of sequences of ints of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))

  • precision (PrecisionLike) – Optional. Either None, which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two Precision enums indicating the precision of lhs and rhs respectively.

  • preferred_element_type (DTypeLike | None) – Optional. Either None, which means the default accumulation type for the input types, or a datatype indicating that results should be accumulated in, and returned with, that datatype.
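The parameters above can be exercised together in a short sketch, assuming a recent JAX installation running on CPU. The first call is a batched matrix product whose dimension_numbers name axis 0 of each operand as the batch dimension, with precision set to Precision.HIGHEST. The second call uses preferred_element_type to accumulate int8 inputs in int32, where the result (30000) would not fit in int8. The shapes and array names are illustrative choices, not part of the API.

```python
import jax.numpy as jnp
from jax import lax

# Batched matrix product: 5 independent (2, 3) @ (3, 4) products.
lhs = jnp.ones((5, 2, 3), dtype=jnp.float32)
rhs = jnp.ones((5, 3, 4), dtype=jnp.float32)

# ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))
dn = (((2,), (1,)), ((0,), (0,)))

out = lax.dot_general(
    lhs, rhs, dn,
    precision=lax.Precision.HIGHEST,  # request the most precise algorithm
)
print(out.shape)  # (5, 2, 4); each entry is 3.0 (three 1.0 * 1.0 terms)

# preferred_element_type: accumulate int8 inputs in int32 and return int32.
a = jnp.full((2, 3), 100, dtype=jnp.int8)
b = jnp.full((3, 2), 100, dtype=jnp.int8)
acc = lax.dot_general(a, b, (((1,), (0,)), ((), ())),
                      preferred_element_type=jnp.int32)
print(acc.dtype, acc[0, 0])  # int32 30000
```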

Returns:

An array whose first dimensions are the (shared) batch dimensions, followed by the lhs non-contracting/non-batch dimensions, and finally the rhs non-contracting/non-batch dimensions.
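To make the output layout concrete, the sketch below places the batch axis of rhs in the middle rather than first, and checks that the result still starts with the batch dimension, followed by the lhs free dimension and then the rhs free dimension. It assumes a recent JAX installation and uses jax.numpy.einsum() only as an independent reference; the shapes are arbitrary.

```python
import jax.numpy as jnp
from jax import lax

# lhs axes: (batch=2, lhs-free=3, contracting=4)
lhs = jnp.arange(24.0).reshape(2, 3, 4)
# rhs axes: (contracting=4, batch=2, rhs-free=5); the batch axis is not first
rhs = jnp.arange(40.0).reshape(4, 2, 5)

dn = (((2,), (0,)), ((0,), (1,)))
out = lax.dot_general(lhs, rhs, dn)

# Batch dims come first, then lhs free dims, then rhs free dims,
# regardless of where those axes sat in the inputs.
print(out.shape)  # (2, 3, 5)

ref = jnp.einsum("bmk,kbn->bmn", lhs, rhs)
print(jnp.allclose(out, ref))  # True
```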

Return type:

Array