jax.lax.batch_matmul

jax.lax.batch_matmul(lhs, rhs, precision=None)

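Batch matrix multiplication. The trailing two axes of each operand are treated as matrices, and all leading axes are batch axes; corresponding matrices in the two stacks are multiplied independently.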
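
A minimal usage sketch, assuming a standard JAX install. It multiplies two stacks of 8 matrices and compares the result with `jnp.einsum` and `jnp.matmul`, which should agree on operands of this shape. Precision is pinned to `lax.Precision.HIGHEST` only so the comparison is not affected by reduced-precision defaults on some accelerators; the shapes and seed are arbitrary choices for illustration.

```python
import jax.numpy as jnp
from jax import lax, random

# Two stacks of 8 matrices: lhs holds 3x4 matrices, rhs holds 4x5 matrices.
key_lhs, key_rhs = random.split(random.PRNGKey(0))
lhs = random.normal(key_lhs, (8, 3, 4))
rhs = random.normal(key_rhs, (8, 4, 5))

# HIGHEST precision avoids reduced-precision defaults on GPU/TPU, so the
# comparisons below are not dominated by rounding differences.
hp = lax.Precision.HIGHEST

# Axis 0 is the batch axis; each 3x4 @ 4x5 product is computed independently.
out = lax.batch_matmul(lhs, rhs, precision=hp)

# Reference results: einsum with a shared batch index b contracting k,
# and jnp.matmul, which also treats leading axes as batch axes.
ref_einsum = jnp.einsum("bmk,bkn->bmn", lhs, rhs, precision=hp)
ref_matmul = jnp.matmul(lhs, rhs, precision=hp)

print(out.shape)  # (8, 3, 5)
print(bool(jnp.allclose(out, ref_einsum, atol=1e-5)))
print(bool(jnp.allclose(out, ref_matmul, atol=1e-5)))
```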

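Parameters:
  • lhs (Array) – left operand of shape (..., m, k); must be at least 2-D.
  • rhs (Array) – right operand of shape (..., k, n); must have the same number of dimensions and the same batch shape as lhs (no broadcasting).
  • precision (PrecisionLike | None) – optional precision setting, forwarded to lax.dot_general.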
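
A sketch of how, to our understanding, these parameters map onto `lax.dot_general`: the last axis of `lhs` is contracted with the second-to-last axis of `rhs`, the remaining leading axes are batch axes, and `precision` is passed through. The `dims` tuple below is our reconstruction for the 3-D case, not text copied from the JAX source. The block also checks that operands with different ndim are rejected with a `ValueError` rather than broadcast; integer-valued float32 inputs keep every product exact.

```python
import jax.numpy as jnp
from jax import lax

# Small integer-valued float32 inputs so every product and sum is exact.
lhs = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)
rhs = jnp.arange(2 * 4 * 5, dtype=jnp.float32).reshape(2, 4, 5)
hp = lax.Precision.HIGHEST

out = lax.batch_matmul(lhs, rhs, precision=hp)

# Assumed equivalent dot_general call (our reconstruction):
# ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
dims = (((2,), (1,)), ((0,), (0,)))
via_dot_general = lax.dot_general(lhs, rhs, dims, precision=hp)
print(bool(jnp.array_equal(out, via_dot_general)))

# Operands with different ndim are not broadcast against each other.
try:
    lax.batch_matmul(lhs, rhs[0])  # (2, 3, 4) with (4, 5)
    ndim_mismatch_rejected = False
except ValueError:
    ndim_mismatch_rejected = True
print(ndim_mismatch_rejected)
```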

Return type:
  Array
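
A short sketch of the returned array's shape, assuming the batch semantics described above: with two leading batch axes the result keeps them and ends in (m, n), and with plain 2-D operands there are no batch axes, so the call reduces to an ordinary matrix product. The shapes are arbitrary illustrative choices.

```python
import jax.numpy as jnp
from jax import lax

# Two leading batch axes (2, 3); the matrices are 4x5 and 5x6.
lhs = jnp.ones((2, 3, 4, 5), dtype=jnp.float32)
rhs = jnp.ones((2, 3, 5, 6), dtype=jnp.float32)
out = lax.batch_matmul(lhs, rhs)
# Every entry is a sum of 5 products of ones, i.e. 5.0.
print(out.shape, out.dtype)  # (2, 3, 4, 6) float32

# With 2-D operands there are no batch axes: an ordinary matmul.
a = jnp.arange(6.0).reshape(2, 3)
b = jnp.arange(12.0).reshape(3, 4)
hp = lax.Precision.HIGHEST
plain = lax.batch_matmul(a, b, precision=hp)
print(bool(jnp.array_equal(plain, jnp.matmul(a, b, precision=hp))))
```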