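jax.lax.batch_matmul

jax.lax.batch_matmul(lhs, rhs, precision=None)

Batch matrix multiplication. The leading axes of both operands are treated as batch dimensions, and for each batch index b the result holds the matrix product of lhs[b] and rhs[b].

Parameters:
  lhs (Array) – left-hand operand: a stack of matrices.
  rhs (Array) – right-hand operand: a stack of matrices with the same batch dimensions as lhs.
  precision (Union[None, str, Precision, tuple[str, str], tuple[Precision, Precision]]) – optional precision setting: None for the backend default, a single Precision value (or its string name) applied to both operands, or a pair giving the precision for lhs and rhs respectively.

Return type:
  Array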
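A minimal usage sketch follows. The shapes, the PRNG key, and the names `ref`, `dn`, and `via_dot_general` are illustrative choices, not part of the API. The sketch assumes both operands have the same number of dimensions and identical batch dimensions (unlike `jnp.matmul`, it does not rely on broadcasting). It compares the result against an `einsum` reference and against a `lax.dot_general` call whose dimension numbers are an assumed equivalent of this operation. `Precision.HIGHEST` is passed so the comparisons are not affected by reduced-precision defaults on accelerators.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Two stacks of 4 matrices: lhs has shape (4, 2, 3), rhs has shape (4, 3, 5).
key_lhs, key_rhs = jax.random.split(jax.random.PRNGKey(0))
lhs = jax.random.normal(key_lhs, (4, 2, 3))
rhs = jax.random.normal(key_rhs, (4, 3, 5))

# One (2, 3) @ (3, 5) product per batch index -> shape (4, 2, 5).
out = lax.batch_matmul(lhs, rhs, precision=lax.Precision.HIGHEST)

# Reference: the same batched contraction written with einsum.
ref = jnp.einsum("bij,bjk->bik", lhs, rhs, precision=lax.Precision.HIGHEST)

# Assumed dot_general equivalent: contract the last axis of lhs with the
# second-to-last axis of rhs, and treat axis 0 of both as the batch axis.
dn = (((lhs.ndim - 1,), (rhs.ndim - 2,)), ((0,), (0,)))
via_dot_general = lax.dot_general(
    lhs, rhs, dimension_numbers=dn, precision=lax.Precision.HIGHEST
)

print(out.shape)
print(jnp.allclose(out, ref, atol=1e-4))
print(jnp.allclose(out, via_dot_general, atol=1e-4))
```

The `precision` argument also accepts a pair, e.g. `precision=(lax.Precision.HIGHEST, lax.Precision.DEFAULT)`, to set the precision of `lhs` and `rhs` independently.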