jax.flatten_util.ravel_pytree

jax.flatten_util.ravel_pytree(pytree)

Ravel (flatten) a pytree of arrays down to a 1D array.

Parameters:

pytree – a pytree of arrays and scalars to ravel.

Returns:

A pair whose first element is a 1D array containing the flattened, concatenated leaf values, with a dtype determined by promoting the dtypes of the leaves, and whose second element is a callable that unflattens a 1D vector of the same length back into a pytree with the same structure as the input pytree. If the input pytree is empty (i.e. has no leaves), then by convention the first element is an empty 1D array of dtype float32.
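A minimal sketch of the round trip described above, assuming a simple dict of arrays standing in for model parameters (the names `params`, `flat`, and `unravel` are illustrative only). It flattens the pytree, operates on the flat vector, maps it back, and shows the empty-pytree convention.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# A small pytree: a dict of arrays (hypothetical "model parameters").
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# flat is a 1D array holding all 2*3 + 3 = 9 leaf values.
flat, unravel = ravel_pytree(params)
print(flat.shape)  # (9,)

# Operate on the flat vector, then map it back to the original structure.
restored = unravel(flat * 2.0)
print(restored["w"].shape, restored["b"].shape)  # (2, 3) (3,)

# An empty pytree yields an empty float32 vector by convention.
empty_flat, _ = ravel_pytree({})
print(empty_flat.shape, empty_flat.dtype)  # (0,) float32
```

This pattern is handy when an optimizer or solver expects a single flat vector rather than a nested structure.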

For details on dtype promotion, see https://jax.readthedocs.io/en/latest/type_promotion.html.
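A short sketch of dtype promotion with mixed leaf dtypes, assuming a tuple of one int32 array and one float32 array (chosen only for illustration). Under JAX's promotion rules int32 and float32 promote to float32, so the flat vector is float32; the returned unravel function converts each chunk back to its leaf's original dtype.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Leaves with different dtypes: int32 and float32.
mixed = (jnp.array([1, 2], dtype=jnp.int32), jnp.array([0.5], dtype=jnp.float32))

flat, unravel = ravel_pytree(mixed)
# int32 and float32 promote to float32, so the flat vector is float32.
print(flat.dtype)  # float32
print(flat.shape)  # (3,)

# Unraveling converts each chunk back to its leaf's original dtype.
ints, floats = unravel(flat)
print(ints.dtype, floats.dtype)  # int32 float32
```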