jax.lax.reshape

jax.lax.reshape(operand, new_sizes, dimensions=None)

Wraps XLA’s Reshape operator.
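
One way to see the underlying XLA operation is to lower a jitted function and inspect the emitted module text. This is a minimal sketch: `to_matrix` is a hypothetical helper, and the exact dialect and op name in the output (for example `stablehlo.reshape` in recent JAX versions) depend on the JAX version, so the check only looks for the substring `reshape`.

```python
import jax
import jax.numpy as jnp
from jax import lax

def to_matrix(x):
    # Hypothetical helper: reshape a length-6 vector into a 2x3 matrix.
    return lax.reshape(x, (2, 3))

# Lower (without compiling) and get the module as text.
hlo_text = jax.jit(to_matrix).lower(jnp.arange(6)).as_text()
has_reshape_op = "reshape" in hlo_text
```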

For inserting/removing dimensions of size 1, prefer using lax.squeeze / lax.expand_dims. These preserve information about axis identity that may be useful for advanced transformation rules.
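
The sketch below contrasts the two approaches on a small vector. `lax.expand_dims` and `lax.squeeze` name the axis being inserted or removed, while the equivalent `lax.reshape` call only states the final shape.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(3)                      # shape (3,)
row = lax.expand_dims(x, (0,))         # shape (1, 3): new size-1 axis at position 0
back = lax.squeeze(row, (0,))          # shape (3,): removes that size-1 axis again

# reshape produces the same values and shape, but does not record
# which axis was inserted.
row_via_reshape = lax.reshape(x, (1, 3))
```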

Parameters:
  • operand (ArrayLike) – array to be reshaped.

  • new_sizes (Shape) – sequence of integers specifying the resulting shape. The total number of elements (the product of new_sizes) must equal the number of elements in operand.

  • dimensions (Sequence[int] | None) – optional sequence of integers specifying a permutation of the input dimensions, applied before the reshape. If specified, its length must equal the number of dimensions of operand (len(operand.shape)).
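
The following sketch illustrates both constraints above. Consistent with the permutation example on this page, passing dimensions should give the same result as transposing first with `lax.transpose` and then reshaping; and a new_sizes whose element count differs from the input's is rejected. The exact exception type is not stated here, so the sketch catches both `TypeError` and `ValueError`.

```python
import jax.numpy as jnp
from jax import lax

y = jnp.arange(6).reshape(2, 3)        # [[0, 1, 2], [3, 4, 5]]

# Permute the input axes via `dimensions`, then reshape in row-major order...
a = lax.reshape(y, (3, 2), dimensions=(1, 0))
# ...which should match an explicit transpose followed by a plain reshape.
b = lax.reshape(lax.transpose(y, (1, 0)), (3, 2))

# 6 elements cannot fill a shape of (4,), so this call should fail.
try:
    lax.reshape(y, (4,))
except (TypeError, ValueError) as err:
    size_error = err
else:
    size_error = None
```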

Returns:

the reshaped array, whose shape is new_sizes.

Return type:

Array

Examples

Simple reshaping from one to two dimensions:

>>> import jax.numpy as jnp
>>> from jax.lax import reshape
>>> x = jnp.arange(6)
>>> y = reshape(x, (2, 3))
>>> y
Array([[0, 1, 2],
       [3, 4, 5]], dtype=int32)

Reshaping back to one dimension:

>>> reshape(y, (6,))
Array([0, 1, 2, 3, 4, 5], dtype=int32)

Reshaping to one dimension with permutation of dimensions:

>>> reshape(y, (6,), (1, 0))
Array([0, 3, 1, 4, 2, 5], dtype=int32)
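
Inside `jax.jit`, the target shape has to be known at trace time. A minimal sketch, assuming the shape is passed as a hashable Python tuple: the hypothetical wrapper `reshape_jit` marks `shape` as static via `static_argnames` so it is not traced.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

# `shape` is static: it stays a plain Python tuple during tracing,
# which lax.reshape needs to build the output shape.
@partial(jax.jit, static_argnames="shape")
def reshape_jit(x, shape):
    return lax.reshape(x, shape)

m = reshape_jit(jnp.arange(6), shape=(3, 2))
```

Each distinct `shape` value triggers a separate compilation, which is the usual trade-off of static arguments.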