jax.lax.collapse

jax.lax.collapse(operand, start_dimension, stop_dimension)

Collapses dimensions of an array into a single dimension.

For example, if operand is an array with shape (2, 3, 4), then collapse(operand, 0, 2).shape == (6, 4). The elements of the collapsed dimension are laid out major-to-minor, i.e., with the lowest-numbered dimension as the slowest-varying dimension.
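A minimal sketch of the example above, assuming a recent JAX install. It checks the resulting shape and shows the major-to-minor layout: with dimension 0 varying slowest, the element at index (i, j) of the first two dimensions lands at position i * 3 + j of the collapsed dimension.

```python
import jax.numpy as jnp
from jax import lax

# A (2, 3, 4) array filled with 0..23 so element positions are easy to track.
x = jnp.arange(24).reshape(2, 3, 4)

# Collapse dimensions 0 and 1 (the half-open range [0, 2)) into one.
y = lax.collapse(x, 0, 2)
print(y.shape)  # (6, 4)

# Major-to-minor layout: dimension 0 varies slowest, so x[i, j] becomes y[i * 3 + j].
i, j = 1, 2
print(bool(jnp.array_equal(y[i * 3 + j], x[i, j])))  # True
```

This ordering is the same as row-major (C-order) reshaping, so here the result matches x.reshape(6, 4).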

Parameters
  • operand (Any) – an input array.

  • start_dimension (int) – the start of the dimensions to collapse (inclusive).

  • stop_dimension (int) – the end of the dimensions to collapse (exclusive).
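Because start_dimension is inclusive and stop_dimension is exclusive, the two parameters name the half-open range [start_dimension, stop_dimension). The sketch below, using an illustrative (2, 3, 4) input, collapses the trailing two dimensions and then all three:

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((2, 3, 4))

# [1, 3) covers dimensions 1 and 2: 3 * 4 = 12.
trailing = lax.collapse(x, 1, 3)
print(trailing.shape)  # (2, 12)

# [0, 3) covers every dimension, flattening the array to rank 1.
flat = lax.collapse(x, 0, 3)
print(flat.shape)  # (24,)
```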

Return type

Any

Returns

An array where dimensions [start_dimension, stop_dimension) have been collapsed (raveled) into a single dimension.
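The returned shape follows a simple rule: dimensions outside the range are kept as-is, and those inside it are replaced by their product. The helper collapsed_shape below is hypothetical, written only to illustrate that rule; the loop compares it against lax.collapse on a few ranges of an illustrative (2, 3, 4, 5) shape.

```python
import math

import jax.numpy as jnp
from jax import lax


def collapsed_shape(shape, start_dimension, stop_dimension):
    """Hypothetical helper: shape after collapsing [start_dimension, stop_dimension)."""
    size = math.prod(shape[start_dimension:stop_dimension])
    return tuple(shape[:start_dimension]) + (size,) + tuple(shape[stop_dimension:])


shape = (2, 3, 4, 5)
for start, stop in [(0, 2), (1, 3), (2, 4), (0, 4)]:
    expected = collapsed_shape(shape, start, stop)
    actual = lax.collapse(jnp.zeros(shape), start, stop).shape
    # Each line prints the range, the predicted shape, and whether JAX agrees.
    print(start, stop, expected, actual == expected)
```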