jax.numpy.reshape

jax.numpy.reshape(a, newshape, order='C')

Gives a new shape to an array without changing its data.

LAX-backend implementation of numpy.reshape().

The JAX version of this function may in some cases return a copy rather than a view of the input.
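A minimal usage sketch, assuming a recent JAX install with `jax.numpy` imported as `jnp` (NumPy is used only to convert results to Python lists for display). The six elements keep their order; only the shape changes:

```python
import numpy as np
import jax.numpy as jnp

# Six elements in a 1-D array: [0, 1, 2, 3, 4, 5]
a = jnp.arange(6)

# Reshape to 2 rows x 3 columns; the element sequence is unchanged.
b = jnp.reshape(a, (2, 3))
print(b.shape)                 # (2, 3)
print(np.asarray(b).tolist())  # [[0, 1, 2], [3, 4, 5]]

# The array method form gives the same result.
c = a.reshape(2, 3)
print(bool((b == c).all()))    # True
```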

Original docstring below.

Parameters
  • a (array_like) – Array to be reshaped.

  • newshape (int or tuple of ints) – The new shape should be compatible with the original shape, that is, it must contain the same total number of elements. If an integer, then the result will be a 1-D array of that length. One shape dimension can be -1. In this case, the value is inferred from the length of the array and the remaining dimensions.

  • order ({'C', 'F', 'A'}, optional) – Read the elements of a using this index order, and place the elements into the reshaped array using this index order. 'C' means to read / write the elements using C-like index order, with the last axis index changing fastest, back to the first axis index changing slowest. 'F' means to read / write the elements using Fortran-like index order, with the first index changing fastest, and the last index changing slowest. Note that the 'C' and 'F' options take no account of the memory layout of the underlying array, and only refer to the order of indexing. 'A' means to read / write the elements in Fortran-like index order if a is Fortran-contiguous in memory, and in C-like order otherwise.
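The sketch below illustrates the newshape and order parameters on a small 2 x 3 array. It assumes a recent JAX install and passes the new shape positionally; only the 'C' and 'F' orders are exercised here:

```python
import numpy as np
import jax.numpy as jnp

a = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# An integer newshape yields a 1-D array of that length.
flat = jnp.reshape(a, 6)
print(flat.shape)  # (6,)

# One dimension may be -1; it is inferred from the total size (6)
# and the remaining dimensions: 6 / 2 = 3.
inferred = jnp.reshape(a, (-1, 2))
print(inferred.shape)  # (3, 2)

# order='C': last index changes fastest when reading and writing.
c_order = jnp.reshape(a, (3, 2), order='C')
print(np.asarray(c_order).tolist())  # [[0, 1], [2, 3], [4, 5]]

# order='F': first index changes fastest. Elements are read as
# 0, 3, 1, 4, 2, 5 and written column by column into the 3 x 2 result.
f_order = jnp.reshape(a, (3, 2), order='F')
print(np.asarray(f_order).tolist())  # [[0, 4], [3, 2], [1, 5]]
```

Both orders produce the same shape; they differ only in which elements land at which positions.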

Returns

reshaped_array – This will be a new view object if possible; otherwise, it will be a copy. Note there is no guarantee of the memory layout (C- or Fortran-contiguous) of the returned array.
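The view-versus-copy distinction above comes from NumPy. Because JAX arrays are immutable, sharing memory with the input is not observable through ordinary array operations: an "update" goes through `.at[...].set(...)`, which returns a new array. A minimal sketch contrasting this with NumPy, where writing through a view also changes the original:

```python
import numpy as np
import jax.numpy as jnp

# JAX: functional update leaves both the input and the reshaped array unchanged.
a = jnp.arange(6)
b = jnp.reshape(a, (2, 3))
b2 = b.at[0, 0].set(100)
print(np.asarray(a).tolist())   # [0, 1, 2, 3, 4, 5]
print(np.asarray(b).tolist())   # [[0, 1, 2], [3, 4, 5]]
print(np.asarray(b2).tolist())  # [[100, 1, 2], [3, 4, 5]]

# NumPy: reshaping a contiguous array returns a view, so in-place
# writes to the view are visible through the original array.
x = np.arange(6)
v = x.reshape(2, 3)
v[0, 0] = 100
print(x.tolist())               # [100, 1, 2, 3, 4, 5]
```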

Return type

ndarray