jax.lax.dot_general
- jax.lax.dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None, out_sharding=None)
General dot product/contraction operator.
Wraps XLA’s DotGeneral operator.
The semantics of dot_general are complicated, but most users should not have to use it directly. Instead, you can use higher-level functions like jax.numpy.dot(), jax.numpy.matmul(), jax.numpy.tensordot(), jax.numpy.einsum(), and others, which will construct appropriate calls to dot_general under the hood. If you really want to understand dot_general itself, we recommend reading XLA’s DotGeneral operator documentation.
- Parameters:
lhs (ArrayLike) – an array
rhs (ArrayLike) – an array
dimension_numbers (DotDimensionNumbers) – a tuple of tuples of sequences of ints of the form
((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))
precision (PrecisionLike | None) – Optional. This parameter controls the numerics of the computation, and it can be one of the following:
None, which means the default precision for the current backend,
a Precision enum value or a tuple of two Precision enums indicating precision of lhs and rhs, or
a DotAlgorithm or a DotAlgorithmPreset indicating the algorithm that must be used to accumulate the dot product.
preferred_element_type (DTypeLike | None) – Optional. This parameter controls the data type output by the dot product. By default, the output element type of this operation will match the lhs and rhs input element types under the usual type promotion rules. Setting preferred_element_type to a specific dtype will mean that the operation returns that element type. When precision is not a DotAlgorithm or DotAlgorithmPreset, preferred_element_type provides a hint to the compiler to accumulate the dot product using this data type.
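The parameters above can be illustrated with a small sketch. The array shapes, the variable names, and the choice to place the batch axis in the middle of each input are assumptions made for illustration only; the einsum comparison is included as a sanity check, not as a statement of how dot_general is implemented.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative inputs: the batch axis is deliberately NOT the leading axis.
lhs = jnp.arange(2 * 4 * 3, dtype=jnp.float32).reshape(2, 4, 3)  # (m, batch, k)
rhs = jnp.arange(3 * 4 * 5, dtype=jnp.float32).reshape(3, 4, 5)  # (k, batch, n)

# ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)):
# contract lhs axis 2 with rhs axis 0, and pair lhs axis 1 with rhs axis 1 as batch.
dimension_numbers = (((2,), (0,)), ((1,), (1,)))

out = lax.dot_general(lhs, rhs, dimension_numbers)
# Batch dimension first, then the remaining lhs axis, then the remaining rhs axis.
print(out.shape)  # (4, 2, 5)

# The same contraction written with the higher-level einsum API.
reference = jnp.einsum("mbk,kbn->bmn", lhs, rhs)
matches = bool(jnp.allclose(out, reference, rtol=1e-2))

# Request a float32 result from bfloat16 inputs via preferred_element_type.
out_f32 = lax.dot_general(
    lhs.astype(jnp.bfloat16),
    rhs.astype(jnp.bfloat16),
    dimension_numbers,
    preferred_element_type=jnp.float32,
)
print(out_f32.dtype)  # float32

# Ask for the highest available precision using a Precision enum value.
out_highest = lax.dot_general(
    lhs, rhs, dimension_numbers, precision=lax.Precision.HIGHEST
)
```

In everyday code the einsum form is usually easier to read; calling dot_general directly is mainly useful when explicit control over dimension_numbers, precision, or preferred_element_type is needed.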
- Returns:
An array whose first dimensions are the (shared) batch dimensions, followed by the lhs non-contracting/non-batch dimensions, and finally the rhs non-contracting/non-batch dimensions.
- Return type: