jax.lax.dot_general
- jax.lax.dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None)
General dot product/contraction operator.
Wraps XLA’s DotGeneral operator.
The semantics of dot_general are complicated, but most users should not have to use it directly. Instead, you can use higher-level functions like jax.numpy.dot(), jax.numpy.matmul(), jax.numpy.tensordot(), jax.numpy.einsum(), and others, which will construct appropriate calls to dot_general under the hood. If you really want to understand dot_general itself, we recommend reading XLA’s DotGeneral operator documentation.
- Parameters
lhs (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – an array
rhs (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – an array
dimension_numbers (Tuple[Tuple[Sequence[int], Sequence[int]], Tuple[Sequence[int], Sequence[int]]]) – a tuple of tuples of sequences of ints of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))
precision (Union[None, str, Precision, Tuple[str, str], Tuple[Precision, Precision]]) – Optional. Either None, which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two Precision enums indicating the precision of lhs and rhs.
preferred_element_type (Union[Any, str, dtype, SupportsDType, None]) – Optional. Either None, which means the default accumulation type for the input types, or a datatype, in which case results are accumulated in, and returned as, that datatype.
- Return type
Array
- Returns
An array whose first dimensions are the (shared) batch dimensions, followed by the lhs non-contracting/non-batch dimensions, and finally the rhs non-contracting/non-batch dimensions.
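To illustrate this output ordering, the sketch below batches over an axis that sits in a different position in each operand and compares the result with jax.numpy.einsum(). The shapes are arbitrary illustrative choices, and the sketch assumes a recent JAX release.

```python
import jax.numpy as jnp
from jax import lax

# Batched matrix product where the batch axis sits at different positions:
#   lhs has shape (batch=5, m=2, k=3), rhs has shape (k=3, batch=5, n=4).
lhs = jnp.arange(30.0).reshape(5, 2, 3)
rhs = jnp.arange(60.0).reshape(3, 5, 4)

dims = (((2,), (0,)),  # contract lhs axis 2 with rhs axis 0 (the k axes)
        ((0,), (1,)))  # batch over lhs axis 0 and rhs axis 1
out = lax.dot_general(lhs, rhs, dims)

# Result layout: batch dims first, then the remaining lhs dims (m),
# then the remaining rhs dims (n).
assert out.shape == (5, 2, 4)
assert jnp.allclose(out, jnp.einsum("bmk,kbn->bmn", lhs, rhs))
```

Although rhs stores its batch axis second, the result places the batch axis first, giving shape (5, 2, 4).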