jax.lax.dot_general
- jax.lax.dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None)
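A more general contraction operator than jax.lax.dot: the contracting and batch dimensions of each operand are given explicitly by dimension_numbers.
Wraps XLA's DotGeneral operator.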
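The sketch below shows how dimension_numbers maps onto familiar products. It assumes a recent JAX release, and the shapes and values are illustrative only. An ordinary matrix product contracts lhs axis 1 with rhs axis 0 and has no batch axes. A batched product additionally pairs axis 0 of both operands as a batch axis. Following XLA's DotGeneral semantics, the result's axes are ordered as the batch axes, then the remaining (free) lhs axes, then the free rhs axes.

```python
import jax.numpy as jnp
from jax import lax

# Plain matrix product: contract lhs axis 1 with rhs axis 0; no batch axes.
a = jnp.arange(6.0).reshape(2, 3)    # shape (2, 3)
b = jnp.arange(12.0).reshape(3, 4)   # shape (3, 4)
mm = lax.dot_general(a, b, dimension_numbers=(((1,), (0,)), ((), ())))
print(mm.shape)  # (2, 4)

# Batched product: axis 0 is a batch axis of both operands, and lhs axis 2
# is contracted with rhs axis 2 (the rhs contracting axis need not be axis 0).
x = jnp.arange(30.0).reshape(5, 2, 3)   # (batch, m, k)
y = jnp.arange(60.0).reshape(5, 4, 3)   # (batch, n, k)
bmm = lax.dot_general(x, y, dimension_numbers=(((2,), (2,)), ((0,), (0,))))
# Result axes: batch axes first, then free lhs axes, then free rhs axes.
print(bmm.shape)  # (5, 2, 4)
```

The batched case computes the same values as jnp.einsum("bmk,bnk->bmn", x, y). einsum is often easier to read, while dot_general states the contracting and batch axes directly.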
- Parameters
  - lhs (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – an array
  - rhs (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – an array
  - dimension_numbers (Tuple[Tuple[Sequence[int], Sequence[int]], Tuple[Sequence[int], Sequence[int]]]) – a tuple of tuples of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))
  - precision (Union[None, str, Precision, Tuple[str, str], Tuple[Precision, Precision]]) – Optional. Either None, which means the default precision for the backend, a Precision enum value (Precision.DEFAULT, Precision.HIGH or Precision.HIGHEST), or a tuple of two Precision enums indicating the precision of lhs and rhs.
  - preferred_element_type (Union[Any, str, dtype, SupportsDType, None]) – Optional. Either None, which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype.
- Return type
Array
- Returns
An array containing the result.
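A second sketch exercises the two optional arguments. It assumes a recent JAX release, and the int8 vectors and float32 matrices are illustrative values. The exact dot product of the two int8 vectors is 4 * 100 * 100 = 40000, which does not fit in int8, so preferred_element_type=jnp.int32 is used to request int32 accumulation and an int32 result. The precision argument is given as a pair of Precision enums, one for lhs and one for rhs.

```python
import jax.numpy as jnp
from jax import lax

# Vector dot product: contract axis 0 of both operands; no batch axes.
dims = (((0,), (0,)), ((), ()))

# int8 inputs whose exact dot product (40000) overflows int8.
u = jnp.full((4,), 100, dtype=jnp.int8)
v = jnp.full((4,), 100, dtype=jnp.int8)

# Accumulate in, and return, int32 instead of the input type int8.
acc = lax.dot_general(u, v, dims, preferred_element_type=jnp.int32)
print(acc.dtype, acc)  # int32 40000

# float32 matrix product requesting the highest precision for both operands.
a = jnp.linspace(0.0, 1.0, 6, dtype=jnp.float32).reshape(2, 3)
b = jnp.linspace(1.0, 2.0, 12, dtype=jnp.float32).reshape(3, 4)
hi = lax.dot_general(a, b, (((1,), (0,)), ((), ())),
                     precision=(lax.Precision.HIGHEST, lax.Precision.HIGHEST))
print(hi.dtype, hi.shape)  # float32 (2, 4)
```

Omitting preferred_element_type in the int8 case would return an int8 result, so the overflowed sum would not be representable.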