jax.lax.collapse

jax.lax.collapse(operand, start_dimension, stop_dimension=None)
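Collapses dimensions of an array into a single dimension.

For example, if operand is an array with shape (2, 3, 4), then collapse(operand, 0, 2).shape == (6, 4). The elements of the collapsed dimension are laid out major-to-minor, i.e., with the lowest-numbered dimension as the slowest varying dimension.

Parameters:
- operand: the input array.
- start_dimension: the first dimension to collapse (inclusive).
- stop_dimension: the end of the range of dimensions to collapse (exclusive). If None, every dimension from start_dimension through the last dimension is collapsed.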

Returns:
An array in which dimensions [start_dimension, stop_dimension) have been collapsed (raveled) into a single dimension.

Return type:
Array
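
The snippet below is a minimal sketch, assuming a recent JAX release in which lax.collapse accepts stop_dimension=None as shown in the signature above. It reproduces the (2, 3, 4) to (6, 4) example, checks the major-to-minor ordering element by element, and shows the stop_dimension=None case. The variable names x, y, and z are illustrative only. The comparison with a row-major reshape describes what the outputs look like, not how collapse is implemented internally.

```python
import jax.numpy as jnp
from jax import lax

# Illustrative input with shape (2, 3, 4); values 0..23 make the layout easy to read.
x = jnp.arange(24).reshape(2, 3, 4)

# Collapse the half-open range [0, 2), i.e. dimensions 0 and 1, into one of size 2 * 3 = 6.
y = lax.collapse(x, 0, 2)
print(y.shape)  # (6, 4)

# Major-to-minor layout: dimension 0 varies slowest, so row k of y is x[k // 3, k % 3].
for k in range(6):
    assert bool(jnp.array_equal(y[k], x[k // 3, k % 3]))

# The result matches a row-major (C-order) reshape of the same array.
assert bool(jnp.array_equal(y, x.reshape(6, 4)))

# With stop_dimension=None, every dimension from start_dimension onward is collapsed.
z = lax.collapse(x, 1)
print(z.shape)  # (2, 12)
assert bool(jnp.array_equal(z, x.reshape(2, 12)))
```

Because the collapsed dimensions must be contiguous, the same effect could be written as an explicit reshape. Using collapse only requires naming the range of dimensions rather than computing the full output shape.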