jax.lax.reshape

jax.lax.reshape(operand, new_sizes, dimensions=None)

Wraps XLA’s Reshape operator.

For inserting/removing dimensions of size 1, prefer using lax.squeeze / lax.expand_dims. These preserve information about axis identity that may be useful for advanced transformation rules.
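As a minimal sketch of this recommendation, the snippet below adds and removes a size-1 axis with lax.expand_dims and lax.squeeze, then produces the same shape with lax.reshape for comparison. The array x and the variable names are illustrative only.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6).reshape(2, 3)          # shape (2, 3)

# Insert a size-1 axis at position 0, then remove it again.
expanded = lax.expand_dims(x, (0,))      # shape (1, 2, 3)
squeezed = lax.squeeze(expanded, (0,))   # shape (2, 3)

# The same resulting shape via lax.reshape, which only sees the sizes
# and does not record which axis was added.
via_reshape = lax.reshape(x, (1, 2, 3))  # shape (1, 2, 3)
```

Both routes yield arrays with identical values here; the difference is only in how much information about the affected axis is available to transformation rules.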

Parameters
  • operand (Any) – array to be reshaped.

  • new_sizes (Sequence[Union[int, Any]]) – sequence of integers specifying the resulting shape. The total number of elements implied by new_sizes must equal the number of elements in operand.

  • dimensions (Optional[Sequence[int]]) – optional sequence of integers specifying the order in which the dimensions of operand are permuted before reshaping. If specified, its length must equal the number of dimensions of operand, i.e. len(operand.shape).
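The constraints on new_sizes and dimensions can be checked with a short sketch. It assumes that passing dimensions behaves like applying lax.transpose with the same permutation and then reshaping, which is consistent with the permutation example in the Examples section. The variable names are illustrative.

```python
import math

import jax.numpy as jnp
from jax import lax

x = jnp.arange(6).reshape(2, 3)
new_sizes = (3, 2)

# new_sizes must describe the same number of elements as the operand.
assert math.prod(new_sizes) == x.size

# dimensions needs one entry per axis of the operand.
dimensions = (1, 0)
assert len(dimensions) == x.ndim

with_dims = lax.reshape(x, new_sizes, dimensions)

# Assumed equivalent: permute the axes first, then reshape.
transposed_first = lax.reshape(lax.transpose(x, dimensions), new_sizes)
```

Under that assumption, with_dims and transposed_first hold the same values.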

Returns

reshaped array.

Return type

out

Examples

Simple reshaping from one to two dimensions:

>>> import jax.numpy as jnp
>>> from jax.lax import reshape
>>> x = jnp.arange(6)
>>> y = reshape(x, (2, 3))
>>> y
DeviceArray([[0, 1, 2],
             [3, 4, 5]], dtype=int32)

Reshaping back to one dimension:

>>> reshape(y, (6,))
DeviceArray([0, 1, 2, 3, 4, 5], dtype=int32)

Reshaping to one dimension with permutation of dimensions:

>>> reshape(y, (6,), (1, 0))
DeviceArray([0, 3, 1, 4, 2, 5], dtype=int32)