jax.lax.collapse

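jax.lax.collapse(operand, start_dimension, stop_dimension)
Collapses dimensions start_dimension through stop_dimension - 1 of operand into a single dimension whose size is the product of their sizes. For example, if operand has shape (2, 3, 4), collapse(operand, 0, 2) has shape (6, 4).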
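Parameters
  • operand (Any) – the input array.

  • start_dimension (int) – the first dimension to collapse (inclusive).

  • stop_dimension (int) – the end of the range of dimensions to collapse (exclusive).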

Return type

Any
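A minimal usage sketch, assuming a recent JAX installation. It shows how the half-open range [start_dimension, stop_dimension) determines which axes are merged. The helper collapse_via_reshape is hypothetical and written only for illustration (it is not part of JAX); it encodes the assumption that collapsing is equivalent to a row-major reshape that multiplies the sizes of the merged axes.

```python
import math

import jax.numpy as jnp
from jax import lax

x = jnp.arange(24).reshape(2, 3, 4)

# Merge axes 0 and 1 (stop_dimension is exclusive): (2, 3, 4) -> (6, 4)
y = lax.collapse(x, 0, 2)
print(y.shape)  # (6, 4)

# Merge axes 1 and 2: (2, 3, 4) -> (2, 12)
z = lax.collapse(x, 1, 3)
print(z.shape)  # (2, 12)


# Hypothetical helper (not a JAX API): the same result expressed as a
# row-major reshape that replaces axes [start, stop) with their product.
def collapse_via_reshape(a, start, stop):
    merged = math.prod(a.shape[start:stop])
    return a.reshape(a.shape[:start] + (merged,) + a.shape[stop:])


assert jnp.array_equal(y, collapse_via_reshape(x, 0, 2))
```

Because the element order is unchanged, collapse is a cheap way to flatten a contiguous run of axes, for example merging a batch and sequence axis before a matrix multiply.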