jax.lax.batch_matmul

jax.lax.batch_matmul(lhs, rhs, precision=None)

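Batch matrix multiplication. Every dimension except the last two is treated as a batch dimension. For each batch index, the corresponding matrix from lhs is multiplied by the corresponding matrix from rhs, so operands of shapes (..., m, k) and (..., k, n) produce a result of shape (..., m, n).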

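Parameters
lhs: array of rank at least 2. All dimensions except the last two are batch dimensions.
rhs: array with the same rank and the same batch dimension sizes as lhs. Its second-to-last dimension is contracted with the last dimension of lhs.
precision: optional precision setting forwarded to jax.lax.dot_general (default None).

Return type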

Any
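The following is a minimal sketch of the semantics described above, expressed with jax.lax.dot_general. It is not the library's own source. It assumes that both operands have equal rank of at least 2 and identical batch sizes, with no broadcasting. The name `batch_matmul_sketch` is a hypothetical helper introduced only for illustration.

```python
import jax.numpy as jnp
from jax import lax

def batch_matmul_sketch(lhs, rhs, precision=None):
    """Hypothetical helper: batched matmul built on lax.dot_general."""
    if min(lhs.ndim, rhs.ndim) < 2:
        raise ValueError("arguments must be at least 2-D")
    if lhs.ndim != rhs.ndim:
        raise ValueError("arguments must have the same number of dimensions")
    # Every leading dimension (all but the last two) is a batch dimension.
    batch = tuple(range(lhs.ndim - 2))
    # Contract the last dim of lhs with the second-to-last dim of rhs.
    contract = ((lhs.ndim - 1,), (rhs.ndim - 2,))
    return lax.dot_general(lhs, rhs, (contract, (batch, batch)),
                           precision=precision)

# A batch of two (3, 4) matrices times a batch of two (4, 5) matrices.
lhs = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)
rhs = jnp.ones((2, 4, 5), dtype=jnp.float32)

out = batch_matmul_sketch(lhs, rhs)
print(out.shape)  # (2, 3, 5)

# The precision argument is passed straight through to dot_general.
out_hi = batch_matmul_sketch(lhs, rhs, precision=lax.Precision.HIGHEST)
```

Because the batch dimensions are required to match exactly, the result here agrees with jnp.matmul. Unlike this sketch, jnp.matmul also broadcasts mismatched batch dimensions.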