jax.lax.batch_matmul

jax.lax.batch_matmul(lhs, rhs, precision=None)

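Batch matrix multiplication.

Multiplies lhs and rhs as stacks of matrices. The last two axes of each operand are the matrix axes. All leading axes are batch axes, and their sizes must match between the two operands. For operands of shape (..., m, k) and (..., k, n), the result has shape (..., m, n).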
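As a minimal sketch of these semantics, the example below multiplies a batch of three (2, 4) matrices by a batch of three (4, 5) matrices and compares the result with an explicit per-batch einsum. The array shapes and values are arbitrary illustration choices. lax.Precision.HIGHEST is passed only so that both computations use full float32 precision on every backend.

```python
import jax.numpy as jnp
from jax import lax

# A batch of 3 matrices of shape (2, 4) and a batch of 3 matrices of shape (4, 5).
lhs = jnp.arange(3 * 2 * 4, dtype=jnp.float32).reshape(3, 2, 4)
rhs = jnp.ones((3, 4, 5), dtype=jnp.float32)

# Request full float32 precision for the underlying dot product.
out = lax.batch_matmul(lhs, rhs, precision=lax.Precision.HIGHEST)
print(out.shape)  # (3, 2, 5)

# Each batch element is an ordinary matrix product: out[b] = lhs[b] @ rhs[b].
expected = jnp.einsum("bik,bkj->bij", lhs, rhs, precision=lax.Precision.HIGHEST)
print(jnp.allclose(out, expected))  # True
```

Because rhs is all ones, each entry of out[0] is a row sum of lhs[0]. For example, out[0, 0, 0] is 0 + 1 + 2 + 3 = 6.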

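Parameters
lhs (Array): array with at least two dimensions, of shape (..., m, k).
rhs (Array): array of shape (..., k, n), with the same number of dimensions and the same batch dimensions as lhs.
precision (PrecisionLike): optional precision setting, forwarded to jax.lax.dot_general(). Defaults to None.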
Return type

Array
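The same computation can be expressed with jax.lax.dot_general by listing the contracting and batch axes explicitly. The helper batch_matmul_via_dot_general below is hypothetical and not part of JAX. It assumes that batch_matmul treats every leading axis as a batch axis and contracts the last axis of lhs with the second-to-last axis of rhs. The sketch checks that assumption by comparing the two results. It also shows that, unlike jnp.matmul, batch_matmul does not broadcast operands with different numbers of dimensions.

```python
import jax.numpy as jnp
from jax import lax

def batch_matmul_via_dot_general(lhs, rhs, precision=None):
    """Hypothetical helper: batched matmul spelled out with lax.dot_general."""
    # All axes except the last two are batch axes, paired positionally.
    batch = tuple(range(lhs.ndim - 2))
    # Contract the last axis of lhs with the second-to-last axis of rhs.
    contracting = ((lhs.ndim - 1,), (rhs.ndim - 2,))
    return lax.dot_general(lhs, rhs, (contracting, (batch, batch)),
                           precision=precision)

# Two batch dimensions (2, 3); each element is a (2, 4) @ (4, 5) product.
lhs = jnp.arange(2 * 3 * 2 * 4, dtype=jnp.float32).reshape(2, 3, 2, 4)
rhs = jnp.arange(2 * 3 * 4 * 5, dtype=jnp.float32).reshape(2, 3, 4, 5)

a = lax.batch_matmul(lhs, rhs, precision=lax.Precision.HIGHEST)
b = batch_matmul_via_dot_general(lhs, rhs, precision=lax.Precision.HIGHEST)
print(a.shape)  # (2, 3, 2, 5)
print(jnp.allclose(a, b))  # True

# Operands must have the same ndim: a 4-D lhs with a 2-D rhs is rejected.
try:
    lax.batch_matmul(lhs, rhs[0, 0])
except ValueError as err:
    print("ValueError:", err)
```

dot_general places batch axes first in its output, followed by the free axes of lhs and then those of rhs. That ordering produces the (..., m, n) layout described above.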