jax.numpy.reshape

Contents

jax.numpy.reshape#

jax.numpy.reshape(a, newshape, order='C')[source]#

Gives a new shape to an array without changing its data.

LAX-backend implementation of numpy.reshape().

The JAX version of this function may in some cases return a copy rather than a view of the input.

Original docstring below.

Parameters:
  • a (array_like) – Array to be reshaped.

  • newshape (int or tuple of ints) – The new shape should be compatible with the original shape. If an integer, then the result will be a 1-D array of that length. One shape dimension can be -1. In this case, the value is inferred from the length of the array and remaining dimensions.

  • order ({'C', 'F', 'A'}, optional) – Read the elements of a using this index order, and place the elements into the reshaped array using this index order. β€˜C’ means to read / write the elements using C-like index order, with the last axis index changing fastest, back to the first axis index changing slowest. β€˜F’ means to read / write the elements using Fortran-like index order, with the first index changing fastest, and the last index changing slowest. Note that the β€˜C’ and β€˜F’ options take no account of the memory layout of the underlying array, and only refer to the order of indexing. β€˜A’ means to read / write the elements in Fortran-like index order if a is Fortran contiguous in memory, C-like order otherwise.

Returns:

reshaped_array – This will be a new view object if possible; otherwise, it will be a copy. Note there is no guarantee of the memory layout (C- or Fortran- contiguous) of the returned array.

Return type:

ndarray