jax.lax.dot_general

jax.lax.dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None)

General dot product/contraction operator.
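The smallest useful case is an ordinary matrix product. The sketch below contracts axis 1 of a (2, 3) matrix with axis 0 of a (3, 4) matrix and uses no batch dimensions. The input values are arbitrary and chosen only for illustration.

```python
import jax.numpy as jnp
from jax import lax

# Arbitrary example operands: a (2, 3) matrix and a (3, 4) matrix.
a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(3, 4)

# dimension_numbers = ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
# Contract lhs axis 1 with rhs axis 0; there are no batch dimensions.
dn = (((1,), (0,)), ((), ()))
out = lax.dot_general(a, b, dimension_numbers=dn)

print(out.shape)                 # (2, 4)
print(jnp.allclose(out, a @ b))  # True
```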

Wraps XLA’s DotGeneral operator.

The semantics of dot_general are complicated, but most users should not have to use it directly. Instead, you can use higher-level functions like jax.numpy.dot(), jax.numpy.matmul(), jax.numpy.tensordot(), jax.numpy.einsum(), and others which will construct appropriate calls to dot_general under the hood. If you really want to understand dot_general itself, we recommend reading XLA’s DotGeneral operator documentation.
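As a rough illustration of how the higher-level functions relate to dot_general, the sketch below performs the same two-axis contraction three ways and checks that the results agree. It only compares results. Exactly which dot_general call jnp.tensordot or jnp.einsum emits is an implementation detail and is not asserted here. The shapes and values are arbitrary.

```python
import jax.numpy as jnp
from jax import lax

# Arbitrary example operands; x's axes 1 and 2 match y's axes 0 and 1 in size.
x = jnp.arange(24.0).reshape(2, 3, 4)
y = jnp.arange(60.0).reshape(3, 4, 5)

# Contract x's axes (1, 2) against y's axes (0, 1); no batch dimensions.
dn = (((1, 2), (0, 1)), ((), ()))
via_lax = lax.dot_general(x, y, dn)

# The same contraction expressed with higher-level jax.numpy functions.
via_tensordot = jnp.tensordot(x, y, axes=((1, 2), (0, 1)))
via_einsum = jnp.einsum("ijk,jkl->il", x, y)

print(via_lax.shape)  # (2, 5)
print(jnp.allclose(via_lax, via_tensordot), jnp.allclose(via_lax, via_einsum))
```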

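Parameters

lhs (ArrayLike): the left-hand operand.
rhs (ArrayLike): the right-hand operand.
dimension_numbers (DotDimensionNumbers): a tuple of tuples of sequences of ints of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims)).
precision (PrecisionLike): optional; controls the numerical precision of the computation. None uses the backend default. A lax.Precision value such as lax.Precision.HIGHEST, or a pair of such values (one per operand), requests a specific precision.
preferred_element_type (DTypeLike | None): optional; the element type of the output. By default, the output dtype follows the usual type-promotion rules applied to the lhs and rhs dtypes.

Return type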

Array

Returns

An array whose first dimensions are the (shared) batch dimensions, followed by the lhs non-contracting/non-batch dimensions, and finally the rhs non-contracting/non-batch dimensions.
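The output ordering described above can be checked with a batched contraction. In this sketch, both operands have a batch axis of size 8 at axis 0 and a contracting axis of size 3 at axis 2. The remaining (non-contracting, non-batch) axes are the lhs free axis of size 2 and the rhs free axis of size 5. Swapping the operands shows that the rhs free dimensions always come last. The shapes and the einsum comparison are illustrative choices, not part of the API.

```python
import jax.numpy as jnp
from jax import lax

# lhs: batch axis 0 (8), free axis 1 (2), contracting axis 2 (3)
lhs = jnp.arange(48.0).reshape(8, 2, 3)
# rhs: batch axis 0 (8), free axis 1 (5), contracting axis 2 (3)
rhs = jnp.arange(120.0).reshape(8, 5, 3)

# ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
dn = (((2,), (2,)), ((0,), (0,)))

out = lax.dot_general(lhs, rhs, dn)
print(out.shape)  # (8, 2, 5): batch, then lhs free, then rhs free

# With the operands swapped, the free dimensions swap places in the output.
swapped = lax.dot_general(rhs, lhs, dn)
print(swapped.shape)  # (8, 5, 2)

# Cross-check against an equivalent einsum.
print(jnp.allclose(out, jnp.einsum("bij,bkj->bik", lhs, rhs)))  # True
```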