jax.flatten_util module


List of Functions

ravel_pytree(pytree)

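Ravel (flatten) a pytree of arrays down to a 1D array. Returns a pair: the 1D array of concatenated leaf values (with dtype obtained by promoting the leaves' dtypes), and a callable that maps a 1D vector of the same length back to a pytree with the same structure as the input.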
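The following is a minimal sketch of a round trip through ravel_pytree. The dictionary `params` and its contents are hypothetical and serve only as an illustration. The sketch shows that the flat vector holds every leaf value, and that the returned callable rebuilds the original structure and shapes.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical pytree: a dict holding two arrays of different shapes.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}

# flat: 1D array of all leaf values concatenated (2*3 + 3 = 9 elements).
# unravel: callable that maps a length-9 vector back to a dict like params.
flat, unravel = ravel_pytree(params)
print(flat.shape)  # (9,)

# Operate on the flat vector, then restore the pytree structure.
restored = unravel(flat * 2.0)
print(restored["w"].shape, restored["b"].shape)  # (2, 3) (3,)
```

Leaves are concatenated in the pytree's flattening order, which for a dict is sorted by key. Here `"b"` therefore comes before `"w"` in `flat`. This pattern suits code that expects a single parameter vector, such as a generic optimizer or a finite-difference check.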