jax.lax.collapse

jax.lax.collapse(operand, start_dimension, stop_dimension=None)

Collapses dimensions of an array into a single dimension.

For example, if operand is an array with shape (2, 3, 4), collapse(operand, 0, 2).shape == (6, 4). The elements of the collapsed dimension are laid out major-to-minor, i.e., the lowest-numbered collapsed dimension is the slowest-varying one. This is the same ordering as a row-major (C-order) reshape.
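The sketch below checks the example above on a concrete array, assuming JAX is installed. `jnp.arange` is used only to give every element a distinct value, so the major-to-minor layout is visible: row `i * 3 + j` of the result holds `x[i, j]`.

```python
import jax.numpy as jnp
from jax import lax

# A (2, 3, 4) array whose entries are 0..23 in row-major order.
x = jnp.arange(24).reshape(2, 3, 4)

# Collapse dimensions 0 and 1 (stop_dimension is exclusive).
y = lax.collapse(x, 0, 2)
print(y.shape)  # (6, 4)

# Major-to-minor layout: dimension 0 varies slowest, so x[i, j]
# lands in row i * 3 + j of the collapsed result.
for i in range(2):
    for j in range(3):
        assert (y[i * 3 + j] == x[i, j]).all()

# The result matches a row-major reshape of the same array.
assert (y == x.reshape(6, 4)).all()
```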

Parameters:
  • operand (Array) – an input array.

  • start_dimension (int) – the first dimension to collapse (inclusive).

  • stop_dimension (int | None) – the end of the range of dimensions to collapse (exclusive). If None (the default), every dimension from start_dimension through the last dimension is collapsed.
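A short sketch of how `stop_dimension` selects the range, assuming JAX is installed. Only the shapes matter here, so the input is an all-zeros array; the variable names are illustrative.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((2, 3, 4, 5))

# Explicit range [1, 3): dimensions 1 and 2 become one of size 3 * 4 = 12.
between = lax.collapse(x, 1, 3)
print(between.shape)  # (2, 12, 5)

# stop_dimension=None: collapse dimension 1 through the last dimension.
trailing = lax.collapse(x, 1)
print(trailing.shape)  # (2, 60)

# Starting at dimension 0 with no stop flattens the whole array.
flat = lax.collapse(x, 0)
print(flat.shape)  # (120,)
```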

Return type:

Array

Returns:

An array where dimensions [start_dimension, stop_dimension) have been collapsed (raveled) into a single dimension.
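The shape rule in the return description can be written out directly. The following is a NumPy reference sketch, assuming that collapsing is equivalent to a row-major (C-order) reshape; `collapse_reference` is a hypothetical helper, not part of JAX, and for simplicity it accepts only non-negative dimension indices and rejects empty ranges.

```python
import math

import numpy as np


def collapse_reference(operand, start_dimension, stop_dimension=None):
    """Hypothetical NumPy sketch of the collapse shape rule (not a JAX API)."""
    shape = np.shape(operand)
    if stop_dimension is None:
        stop_dimension = len(shape)  # collapse through the last dimension
    if not 0 <= start_dimension < stop_dimension <= len(shape):
        raise ValueError("expected 0 <= start_dimension < stop_dimension <= ndim")
    # Dimensions [start, stop) are replaced by one dimension whose size
    # is the product of their sizes; all other dimensions are unchanged.
    size = math.prod(shape[start_dimension:stop_dimension])
    new_shape = shape[:start_dimension] + (size,) + shape[stop_dimension:]
    # A C-order reshape keeps the lowest-numbered collapsed dimension slowest-varying.
    return np.reshape(operand, new_shape, order="C")


x = np.arange(24).reshape(2, 3, 4)
print(collapse_reference(x, 0, 2).shape)  # (6, 4)
print(collapse_reference(x, 1).shape)     # (2, 12)
```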