- jax.lax.collapse(operand, start_dimension, stop_dimension=None)
Collapses dimensions of an array into a single dimension.
For example, if operand is an array with shape [2, 3, 4], then collapse(operand, 0, 2).shape == [6, 4], because dimensions 0 and 1 (sizes 2 and 3) are merged into a single dimension of size 2 × 3 = 6. The elements of the collapsed dimension are laid out major-to-minor, i.e., with the lowest-numbered dimension as the slowest-varying dimension.
- Returns:
An array where dimensions [start_dimension, stop_dimension) (start inclusive, stop exclusive) have been collapsed (raveled) into a single dimension.
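The snippet below is a minimal sketch, assuming a recent JAX installation, that checks the shape example and the major-to-minor layout against an ordinary reshape. The arrays x and z are illustrative only. The final call assumes that stop_dimension=None (the default) collapses everything from start_dimension through the last dimension; that reflects how JAX handles None, but it is worth confirming against your installed version.

```python
import jax.numpy as jnp
from jax import lax

# A [2, 3, 4] array whose values are their own row-major flat indices.
x = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)

# Collapse the half-open range [0, 2): sizes 2 and 3 merge into 6.
y = lax.collapse(x, 0, 2)
print(y.shape)  # (6, 4)

# Major-to-minor layout: dimension 0 varies slowest, so row k of y is
# x[k // 3, k % 3], which matches an ordinary row-major reshape.
assert jnp.array_equal(y, x.reshape(6, 4))
assert jnp.array_equal(y[4], x[1, 1])

# Collapsing interior dimensions [1, 3) of a 4-D array leaves the
# outer dimensions untouched: 3 * 4 = 12.
z = jnp.zeros((2, 3, 4, 5))
print(lax.collapse(z, 1, 3).shape)  # (2, 12, 5)

# stop_dimension=None: assumed here to collapse from start_dimension
# through the last dimension, giving 3 * 4 * 5 = 60.
print(lax.collapse(z, 1).shape)  # (2, 60)
```

Because the collapsed elements keep their major-to-minor order, collapse behaves like a restricted reshape: it merges one contiguous run of dimensions and never reorders elements.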