jax.lax.dot_general

jax.lax.dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None, out_sharding=None)

General dot product/contraction operator.

Wraps XLA’s DotGeneral operator.

The semantics of dot_general are complicated, but most users should not need to call it directly. Higher-level functions such as jax.numpy.dot(), jax.numpy.matmul(), jax.numpy.tensordot(), and jax.numpy.einsum() construct the appropriate calls to dot_general under the hood. To understand dot_general itself in detail, we recommend reading XLA’s DotGeneral operator documentation.
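As a rough illustration of how the higher-level wrappers map onto dot_general, the sketch below writes an ordinary matrix product directly: axis 1 of lhs is contracted with axis 0 of rhs, and there are no batch axes. The arrays are arbitrary test data; the comparison assumes float32 on a CPU backend, where the default precision does not truncate inputs.

```python
import jax.numpy as jnp
from jax import lax

a = jnp.arange(6.0).reshape(2, 3)   # shape (2, 3)
b = jnp.arange(12.0).reshape(3, 4)  # shape (3, 4)

# ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
dn = (((1,), (0,)), ((), ()))
out = lax.dot_general(a, b, dimension_numbers=dn)

# The same product expressed with the higher-level wrappers.
assert jnp.allclose(out, jnp.matmul(a, b))
assert jnp.allclose(out, jnp.einsum("ij,jk->ik", a, b))
```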

Parameters:
  • lhs (ArrayLike) – an array

  • rhs (ArrayLike) – an array

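  • dimension_numbers (DotDimensionNumbers) – a tuple of tuples of sequences of ints of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)), listing the axes of lhs and rhs that are contracted (summed over) and the axes that are treated as batch dimensions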

  • precision (PrecisionLike | None) –

    Optional. This parameter controls the numerics of the computation, and it can be one of the following:

    • None, which means the default precision for the current backend,

    • a Precision enum value, or a tuple of two Precision enums indicating the precision of lhs and rhs respectively, or

    • a DotAlgorithm or a DotAlgorithmPreset indicating the algorithm that must be used to accumulate the dot product.

  • preferred_element_type (DTypeLike | None) – Optional. This parameter controls the data type output by the dot product. By default, the output element type matches the lhs and rhs input element types under the usual type promotion rules. Setting preferred_element_type to a specific dtype makes the operation return that element type. When precision is not a DotAlgorithm or DotAlgorithmPreset, preferred_element_type also serves as a hint to the compiler to accumulate the dot product in this data type.
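A minimal sketch of the parameters together: a batched matrix multiply where axis 0 is the batch axis on both operands, axis 2 of lhs is contracted with axis 1 of rhs, and bfloat16 inputs are asked to produce a float32 result. The shapes and dtypes are arbitrary choices for illustration, and lax.Precision.HIGHEST is only a request; how it is honored depends on the backend.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((5, 2, 3), dtype=jnp.bfloat16)  # (batch, m, k)
y = jnp.ones((5, 3, 4), dtype=jnp.bfloat16)  # (batch, k, n)

# Contract lhs axis 2 with rhs axis 1; batch over axis 0 of both.
dn = (((2,), (1,)), ((0,), (0,)))

# Default: the output dtype follows the promoted input dtype (bfloat16).
out_default = lax.dot_general(x, y, dn)

# Request a float32 result (also a hint to accumulate in float32)
# and the highest available precision for both operands.
out_f32 = lax.dot_general(
    x, y, dn,
    precision=lax.Precision.HIGHEST,
    preferred_element_type=jnp.float32,
)
# out_default: dtype bfloat16, shape (5, 2, 4)
# out_f32:     dtype float32,  shape (5, 2, 4)
```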

Returns:

An array whose first dimensions are the (shared) batch dimensions, followed by the lhs non-contracting/non-batch dimensions, and finally the rhs non-contracting/non-batch dimensions.
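To make the output ordering concrete, the sketch below places the batch and contracting axes in unusual positions and checks the result against an equivalent jax.numpy.einsum() expression. The shapes are hypothetical; the tolerance assumes float32 on a CPU backend.

```python
import jax.numpy as jnp
from jax import lax, random

key_l, key_r = random.split(random.PRNGKey(0))
lhs = random.normal(key_l, (3, 2, 4, 6))  # axes: free, batch, contract, free
rhs = random.normal(key_r, (4, 7, 2))     # axes: contract, free, batch

# Contract lhs axis 2 with rhs axis 0; batch over lhs axis 1 and rhs axis 2.
dn = (((2,), (0,)), ((1,), (2,)))
out = lax.dot_general(lhs, rhs, dn)

# Batch dim (2) first, then lhs free dims (3, 6), then the rhs free dim (7).
assert out.shape == (2, 3, 6, 7)

# Equivalent einsum: b = batch, c = contracted, output order "bade".
ref = jnp.einsum("abcd,ceb->bade", lhs, rhs)
assert jnp.allclose(out, ref, atol=1e-4)
```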

Return type:

Array