jax.pure_callback
- jax.pure_callback(callback, result_shape_dtypes, *args, sharding=None, vectorized=False, **kwargs)
Calls a pure Python callback. Works under jit()/vmap()/etc. For more explanation, see External Callbacks.
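As a minimal sketch of the basic pattern, the example below calls a NumPy function, which JAX cannot trace, from inside a jit-compiled function. The names host_sin and f are illustrative only. Because JAX cannot infer the output of NumPy code, the expected shape and dtype are declared with jax.ShapeDtypeStruct and passed as result_shape_dtypes.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sin(x):
    # Runs in the Python interpreter on the host, outside the compiled
    # program; np.asarray handles either NumPy or JAX array inputs.
    return np.sin(np.asarray(x))


@jax.jit
def f(x):
    # Declare the shape and dtype the callback will return, since JAX
    # cannot trace through NumPy code to discover them.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, out_type, x)


x = jnp.linspace(0.0, 1.0, 5, dtype=jnp.float32)
y = f(x)
```

Each call transfers data to the host and back, so for something this simple jnp.sin would normally be preferred; the callback route is meant for host-only code that has no JAX equivalent.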
pure_callback enables calling a Python function in JIT-ed JAX functions. The input callback will be passed JAX arrays placed on a local CPU, and it should also return JAX arrays on CPU.

The callback is treated as functionally pure, meaning it has no side-effects and its output value depends only on its argument values. As a consequence, it may safely be called multiple times (e.g. when transformed by vmap() or pmap()), or not be called at all when, e.g., the output of a jit-decorated function has no data dependence on its value. Pure callbacks may also be reordered if data dependence allows.

When vmap-ed, the behavior depends on the value of the vectorized keyword argument. When vectorized is True, the callback is assumed to obey jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs]). Therefore, the callback will be called directly on batched inputs (where the batch axes are the leading dimensions), and it should return outputs that have corresponding leading batch axes. If vectorized is False, callback will be mapped sequentially across the batched axis. For example, if callback = lambda x, y: np.matmul(x, y), then we are free to set vectorized=True because the np.matmul function handles arbitrary leading batch dimensions.
- Parameters:
  - callback (Callable[[...], Any]) – function to execute on the host. The callback is assumed to be a pure function (i.e. one without side-effects): if an impure function is passed, it may behave in unexpected ways, particularly under transformation. The callable will be passed PyTrees of arrays as arguments, and should return a PyTree of arrays that matches result_shape_dtypes.
  - result_shape_dtypes (Any) – pytree whose leaves have shape and dtype attributes and whose structure matches the expected output of the callback function at runtime. jax.ShapeDtypeStruct is often used to define leaf values.
  - *args (Any) – arguments to be passed to the callback function.
  - sharding (SingleDeviceSharding | None) – optional sharding that specifies the device from which the callback should be invoked.
  - vectorized (bool) – boolean specifying whether the callback function can operate in a vectorized manner.
  - **kwargs (Any) – keyword arguments to be passed to the callback function.
- Returns:
  - a pytree of jax.Array objects whose structure matches that of result_shape_dtypes.
- Return type:
result
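The Returns entry says the output mirrors the structure of result_shape_dtypes. As a sketch of that with a non-array pytree, the hypothetical host_min_max below returns a dict, and one of its inputs arrives through **kwargs rather than *args. The function names and the scale argument are illustrative, not part of the API.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_min_max(x, *, scale):
    # Hypothetical host function: scale the input, then return a dict
    # (a pytree) of 0-d arrays.
    x = np.asarray(x) * np.asarray(scale)
    return {"min": np.asarray(x.min()), "max": np.asarray(x.max())}


@jax.jit
def min_max(x, scale):
    # One ShapeDtypeStruct per output leaf; the dict keys must match
    # exactly what host_min_max returns.
    leaf = jax.ShapeDtypeStruct((), x.dtype)
    out_type = {"min": leaf, "max": leaf}
    # `scale` is forwarded to the callback as a keyword argument.
    return jax.pure_callback(host_min_max, out_type, x, scale=scale)


x = jnp.array([3.0, -1.0, 2.0], dtype=jnp.float32)
result = min_max(x, jnp.float32(2.0))
```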
See also
jax.experimental.io_callback(): callback designed for impure functions.
jax.debug.callback(): callback designed for general-purpose debugging.
jax.debug.print(): callback designed for printing.
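To illustrate the vectorized behavior described earlier, the sketch below runs the np.matmul case from the text under jax.vmap. One assumption needs flagging: releases matching this page accept vectorized=True, but newer JAX releases deprecated that argument in favor of vmap_method, so the sketch inspects the installed signature and uses vmap_method="expand_dims" when it exists. The names host_matmul, matmul and BATCH_KWARGS are hypothetical.

```python
import inspect

import jax
import jax.numpy as jnp
import numpy as np


def host_matmul(x, y):
    # np.matmul treats all but the last two axes as batch axes, so it
    # satisfies the vectorized contract described in the text.
    return np.matmul(np.asarray(x), np.asarray(y))


# Version shim (assumption): older releases take `vectorized=True`;
# newer ones expose `vmap_method` instead.
_params = inspect.signature(jax.pure_callback).parameters
if "vmap_method" in _params:
    BATCH_KWARGS = {"vmap_method": "expand_dims"}
else:
    BATCH_KWARGS = {"vectorized": True}


def matmul(x, y):
    # Shapes seen here are per-example; vmap hides the batch axis.
    out_type = jax.ShapeDtypeStruct(
        (x.shape[0], y.shape[1]), jnp.result_type(x, y)
    )
    return jax.pure_callback(host_matmul, out_type, x, y, **BATCH_KWARGS)


xs = jnp.arange(12, dtype=jnp.float32).reshape(2, 2, 3)
ys = jnp.ones((2, 3, 4), dtype=jnp.float32)
# With vectorization enabled, host_matmul is invoked once on the whole
# (2, 2, 3) and (2, 3, 4) batches rather than once per example.
out = jax.jit(jax.vmap(matmul))(xs, ys)
```

If the callback could not handle leading batch axes, leaving vectorized at its default of False would instead make JAX map it sequentially, one call per example.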