jax.numpy.reshape

jax.numpy.reshape(a, shape=None, order='C', *, newshape=Deprecated)

Return a reshaped copy of an array.

JAX implementation of numpy.reshape(), implemented in terms of jax.lax.reshape().
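As a quick illustration of the relationship with the lower-level primitive, the sketch below assumes the default C order and a fully specified shape (no -1 entries). It checks that jnp.reshape and jax.lax.reshape agree on such inputs. It does not show how jnp.reshape handles -1 or order='F' internally.

```python
import jax
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# With C order and an explicit target shape, the NumPy-style wrapper and
# the lax primitive produce the same array.
a = jnp.reshape(x, (3, 2))
b = jax.lax.reshape(x, (3, 2))  # new_sizes must be given explicitly here
print(bool((a == b).all()))  # True
```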

Parameters:
  • a (jax.typing.ArrayLike) – input array to reshape

  • shape (int | Any | Sequence[int | Any] | None) – integer or sequence of integers giving the new shape. The product of the new dimensions must equal the size of the input array. At most one dimension may be given as -1; it is replaced by the value that makes the output size match the input size.

  • order (str) – 'F' or 'C', specifying whether the reshape reads and writes elements in column-major (Fortran-style, "F") or row-major (C-style, "C") order; default is "C". JAX does not support order="A".

  • newshape (int | Any | Sequence[int | Any] | DeprecatedArg) – deprecated alias of the shape argument; use shape instead.
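The constraints on shape and order can be checked directly. The sketch below triggers a size mismatch and an unsupported order="A"; because the exact exception class is an implementation detail that may vary between JAX versions, it catches a small set of plausible types rather than asserting a specific one.

```python
import jax.numpy as jnp

x = jnp.arange(6)

# The new shape must have the same total size as the input (6 here);
# (4, 2) has size 8, so this is rejected.
try:
    jnp.reshape(x, (4, 2))
    size_error = None
except (TypeError, ValueError) as e:
    size_error = type(e).__name__

# order="A" is accepted by NumPy but not by JAX.
try:
    jnp.reshape(x, (2, 3), order="A")
    order_error = None
except (NotImplementedError, ValueError) as e:
    order_error = type(e).__name__

print(size_error, order_error)
```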

Returns:

reshaped copy of input array with the specified shape.

Return type:

Array

Notes

Unlike numpy.reshape(), jax.numpy.reshape() returns a copy rather than a view of the input array. However, under JIT the compiler optimizes away such copies when possible, so in practice this typically has no performance impact.
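The practical difference between a view and a copy shows up when writing to the result. The sketch below contrasts NumPy, where a write through the reshaped view changes the original, with JAX, where arrays are immutable and updates through .at[...] return new arrays. The final jit-compiled function is only meant to show a reshape inside a compiled program; whether a physical copy is avoided is up to the compiler.

```python
import numpy as np
import jax
import jax.numpy as jnp

# NumPy: reshape of a contiguous array returns a view, so a write
# through the view is visible in the original array.
x_np = np.arange(6)
v_np = np.reshape(x_np, (2, 3))
v_np[0, 0] = 100
print(x_np[0])  # 100

# JAX: arrays are immutable; .at[...].set returns a new array, so
# neither the input nor the reshaped result is modified in place.
x = jnp.arange(6)
y = jnp.reshape(x, (2, 3))
y2 = y.at[0, 0].set(100)
print(x[0], y[0, 0], y2[0, 0])  # 0 0 100

# Under jax.jit the reshape is one operation in the compiled program.
f = jax.jit(lambda a: jnp.reshape(a, (3, 2)).sum(axis=0))
print(f(x))  # [6 9]
```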

See also

  • jax.Array.reshape(): equivalent functionality via an array method.

  • jax.numpy.ravel(): flatten an array into a 1D shape.

  • jax.numpy.squeeze(): remove one or more length-1 axes from an array’s shape.
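For the cases checked below, the related functions listed above can be expressed as particular reshapes. This is a sketch of the equivalences for a few concrete shapes, not a claim about how those functions are implemented.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# jax.Array.reshape: the method form of the same operation.
method_ok = bool((x.reshape(3, 2) == jnp.reshape(x, (3, 2))).all())

# jnp.ravel flattens to 1D, which matches reshaping to -1 in C order.
ravel_ok = bool((jnp.ravel(x) == jnp.reshape(x, -1)).all())

# jnp.squeeze drops length-1 axes; for this input, reshaping to (3,)
# gives the same shape.
z = jnp.zeros((1, 3, 1))
squeeze_ok = jnp.squeeze(z).shape == jnp.reshape(z, (3,)).shape == (3,)

print(method_ok, ravel_ok, squeeze_ok)  # True True True
```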

Examples

>>> import jax.numpy as jnp
>>> x = jnp.array([[1, 2, 3],
...                [4, 5, 6]])
>>> jnp.reshape(x, 6)
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>> jnp.reshape(x, (3, 2))
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)

You can pass -1 for one dimension to have its size inferred from the size of the input:

>>> jnp.reshape(x, -1)  # -1 is inferred to be 6
Array([1, 2, 3, 4, 5, 6], dtype=int32)
>>> jnp.reshape(x, (-1, 2))  # -1 is inferred to be 3
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)
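The inference rule for -1 amounts to a small integer computation: divide the input size by the product of the known dimensions. The helper infer_shape below is hypothetical (it is not part of JAX) and simply restates that rule so it can be compared against the shapes jnp.reshape actually produces.

```python
import math
import jax.numpy as jnp

def infer_shape(size, shape):
    """Hypothetical helper: resolve a single -1 entry in `shape`.

    The -1 dimension becomes whatever value makes the product of the
    shape equal `size`.
    """
    shape = (shape,) if isinstance(shape, int) else tuple(shape)
    if shape.count(-1) > 1:
        raise ValueError("can only specify one unknown dimension")
    if -1 not in shape:
        return shape
    known = math.prod(d for d in shape if d != -1)
    if known == 0 or size % known != 0:
        raise ValueError(f"cannot reshape size {size} into {shape}")
    return tuple(size // known if d == -1 else d for d in shape)

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])
for s in [-1, (-1, 2), (3, -1), (1, -1, 3)]:
    assert jnp.reshape(x, s).shape == infer_shape(x.size, s)
print(infer_shape(x.size, (-1, 2)))  # (3, 2)
```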

By default, the reshape reads and writes elements in C-style row-major order. To use Fortran-style column-major order, specify order='F':

>>> jnp.reshape(x, 6, order='F')
Array([1, 4, 2, 5, 3, 6], dtype=int32)
>>> jnp.reshape(x, (3, 2), order='F')
Array([[1, 5],
       [4, 3],
       [2, 6]], dtype=int32)
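One way to see what order='F' does: reversing all axes, reshaping in C order to the reversed target shape, and reversing the axes again visits elements in column-major order. The helper reshape_fortran below is hypothetical and only illustrates this equivalence; it is not necessarily how JAX implements order='F'.

```python
import jax.numpy as jnp

def reshape_fortran(a, shape):
    """Hypothetical helper: Fortran-order reshape built from C-order reshapes.

    a.T reverses all axes, so a C-order flatten of a.T equals an F-order
    flatten of a. Reshaping to the reversed target shape and transposing
    back fills the output in column-major order.
    """
    shape = (shape,) if isinstance(shape, int) else tuple(shape)
    return jnp.reshape(a.T, shape[::-1]).T

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])
for s in [6, (3, 2), (1, 6)]:
    assert (reshape_fortran(x, s) == jnp.reshape(x, s, order='F')).all()
```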

For convenience, this functionality is also available via the jax.Array.reshape() method:

>>> x.reshape(3, 2)
Array([[1, 2],
       [3, 4],
       [5, 6]], dtype=int32)
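As a sketch of the method's calling conventions, the example below passes the shape either as separate integers or as a single tuple, and uses the order keyword with the method. Both spellings of the shape give the same result in this example.

```python
import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])

# Shape as separate integers or as one tuple.
a = x.reshape(3, 2)
b = x.reshape((3, 2))

# The method also takes the `order` keyword.
c = x.reshape(3, 2, order='F')
```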