class jax.tree_util.Partial

A version of functools.partial that works in pytrees.

Use it for partial function evaluation in a way that is compatible with JAX’s transformations, e.g., Partial(func, *args, **kwargs).

(You need to explicitly opt in to this behavior because we didn’t want to give functools.partial different semantics from normal function closures.)

For example, here is a basic usage of Partial in a manner similar to functools.partial:

>>> import jax.numpy as jnp
>>> from jax.tree_util import Partial
>>> add_one = Partial(jnp.add, 1)
>>> add_one(2)
DeviceArray(3, dtype=int32)
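Since Partial is a pytree, the positional and keyword arguments it binds act as the pytree's leaves. The wrapped function does not appear among them. The sketch below illustrates this with tree_flatten, tree_unflatten and tree_map from jax.tree_util. It assumes a recent JAX release, and `scaled_add` is a hypothetical helper used only for illustration.

```python
from jax.tree_util import Partial, tree_flatten, tree_map, tree_unflatten

# Hypothetical helper, not part of JAX.
def scaled_add(x, y, scale=1.0):
    return scale * (x + y)

p = Partial(scaled_add, 1.0, scale=2.0)

# Bound positional and keyword arguments become leaves;
# the wrapped function is not among them.
leaves, treedef = tree_flatten(p)
print(leaves)  # [1.0, 2.0]

# Rebuild a Partial with new leaf values: same function, same structure.
q = tree_unflatten(treedef, [10.0, 3.0])
print(q(1.0))  # 3.0 * (10.0 + 1.0) = 33.0

# tree_map only touches the bound arguments.
doubled = tree_map(lambda v: v * 2, p)
print(doubled.args, doubled.keywords)  # (2.0,) {'scale': 4.0}
```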

Pytree compatibility means that the resulting partial function can be passed as an argument to JAX-transformed functions, which is not possible with a standard functools.partial object:

>>> from jax import jit
>>> @jit
... def call_func(f, *args):
...   return f(*args)
>>> call_func(add_one, 2)
DeviceArray(3, dtype=int32)
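To make the contrast with functools.partial concrete, the following sketch passes both kinds of object to the same jitted function. It assumes, consistent with the text, that jit treats an unregistered functools.partial as a single non-array leaf and rejects it with a TypeError. The function `call_func` mirrors the doctest above.

```python
import functools

import jax
import jax.numpy as jnp
from jax.tree_util import Partial

@jax.jit
def call_func(f, *args):
    return f(*args)

# jax.tree_util.Partial is a pytree, so jit accepts it as an argument.
ok = call_func(Partial(jnp.add, 1), 2)
print(ok)  # 3

# functools.partial is not a registered pytree: jit sees it as one leaf,
# cannot turn it into an array, and raises a TypeError.
try:
    call_func(functools.partial(jnp.add, 1), 2)
    raised = False
except TypeError:
    raised = True
print(raised)  # True
```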

Passing zero arguments to Partial effectively wraps the original function, making it a valid argument to JAX-transformed functions:

>>> call_func(Partial(jnp.add), 1, 2)
DeviceArray(3, dtype=int32)

Had we passed jnp.add to call_func directly, it would have resulted in a TypeError.
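A zero-argument Partial has no leaves at all, so under jit the wrapped function is effectively part of the argument's static structure. A reasonable consequence to expect is that reusing the same function hits the compilation cache, while wrapping a different function triggers a new trace. The sketch below checks this by recording a Python side effect, which runs only while tracing. The `traces` list is an illustrative device, not a JAX API.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial, tree_leaves

traces = []  # appended to only while call_func is being traced

@jax.jit
def call_func(f, *args):
    traces.append(f.func)  # Python side effect: skipped on cached calls
    return f(*args)

wrapped_add = Partial(jnp.add)
print(tree_leaves(wrapped_add))  # []: nothing bound, so no leaves

call_func(wrapped_add, 1, 2)
call_func(Partial(jnp.add), 3, 4)       # same function and shapes: cached
call_func(Partial(jnp.multiply), 3, 4)  # different function: retraced
print(len(traces))  # 2
```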

Note that if the result of Partial is used in a context where the value is traced, all of its bound arguments are traced as well when they are passed to the partially evaluated function:

>>> print_zero = Partial(print, 0)
>>> print_zero()
0
>>> call_func(print_zero)
Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
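A practical upside of bound arguments being traced is that changing a bound value, while keeping its shape and dtype, should not force recompilation. The sketch below checks this with a trace counter. It assumes Python integer literals are treated as weakly typed int32 scalars, as in the output above. The `trace_log` list is illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial

trace_log = []  # appended to only while call_func is being traced

@jax.jit
def call_func(f, *args):
    trace_log.append(None)  # Python side effect: runs during tracing only
    return f(*args)

# The bound values 1 and 5 are traced inputs, not compile-time constants,
# so both calls share one trace and one compiled executable.
a = call_func(Partial(jnp.add, 1), 2)
b = call_func(Partial(jnp.add, 5), 2)
print(int(a), int(b), len(trace_log))  # 3 7 1
```

If a bound value must be a concrete Python value inside the function (for example, to drive Python control flow), binding it through Partial will not provide that, because it arrives as a tracer.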

Attributes (inherited from functools.partial):

args: tuple of arguments to future partial calls

func: function object to use in future partial calls

keywords: dictionary of keyword arguments to future partial calls
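Because Partial subclasses functools.partial, the read-only func, args and keywords attributes can be used to inspect what a Partial has bound. A small sketch follows. The function `affine` is a hypothetical example, not part of JAX.

```python
from jax.tree_util import Partial

# Hypothetical function used only for illustration.
def affine(w, x, b=0.0):
    return w * x + b

p = Partial(affine, 2.0, b=1.0)
print(p.func is affine)  # True
print(p.args)            # (2.0,)
print(p.keywords)        # {'b': 1.0}
print(p(3.0))            # affine(2.0, 3.0, b=1.0) = 7.0
```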