jax.tree_util.tree_reduce

jax.tree_util.tree_reduce(function: Callable[[T, Any], T], tree: Any) → T
jax.tree_util.tree_reduce(function: Callable[[T, Any], T], tree: Any, initializer: T) → T
Parameters
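  • function (Callable[[~T, Any], ~T]) – binary function called as function(accumulator, leaf); it returns the updated accumulator.

  • tree (Any) – the pytree whose leaves are reduced.

  • initializer (Any) – optional starting value for the accumulator. If omitted, the first leaf is used as the starting value.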

Return type

~T
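
Example

The following is a minimal sketch of both call forms, assuming tree_reduce folds over the leaves in the order returned by jax.tree_util.tree_leaves, in the same way functools.reduce folds over a list. The tree contents and variable names are illustrative only.

```python
import operator

from jax.tree_util import tree_reduce

# An illustrative pytree: a dict holding a tuple and a list of plain ints.
tree = {"a": 1, "b": (2, 3), "c": [4]}

# Without an initializer, the first leaf seeds the accumulator.
total = tree_reduce(operator.add, tree)

# With an initializer, the accumulator starts from the given value.
total_with_init = tree_reduce(operator.add, tree, 10)

# Any binary function of (accumulator, leaf) works, e.g. the built-in max.
largest = tree_reduce(max, tree)

print(total, total_with_init, largest)
```

Passing an initializer is the safer choice when the tree may have no leaves, since a reduction over zero leaves has no first value to start from.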