jax.pure_callback

jax.pure_callback(callback, result_shape_dtypes, *args, sharding=None, vectorized=False, **kwargs)

Calls a pure Python callback. Works under jit()/vmap()/etc.

For more explanation, see External Callbacks.

pure_callback enables calling a Python function from within JIT-compiled JAX functions. The callback will be passed JAX arrays placed on a local CPU device, and it should likewise return JAX arrays on the CPU.
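A minimal sketch of the basic pattern follows. It assumes 32-bit floats (JAX's default when x64 is disabled), and the helper name host_sin is hypothetical. A NumPy function is wrapped with pure_callback inside a jit-compiled function, and jax.ShapeDtypeStruct describes the expected output. The callback converts its input with np.asarray so that it works whichever array type it receives.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
    # Runs on the host; convert to a NumPy array before calling NumPy.
    x = np.asarray(x)
    # The returned dtype must match the dtype declared in result_shape_dtypes.
    return np.sin(x).astype(x.dtype)

@jax.jit
def f(x):
    # Describe the callback's output: same shape and dtype as the input.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, out_type, x)

x = jnp.arange(4.0, dtype=jnp.float32)
y = f(x)
```

If the callback returns an array whose shape or dtype differs from out_type, JAX reports an error at runtime, so the declared result type should be kept in sync with what the host function actually produces.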

The callback is treated as functionally pure, meaning it has no side effects and its output value depends only on its argument values. As a consequence, it is safe for the callback to be called multiple times (e.g. when transformed by vmap() or pmap()), or not to be called at all when, e.g., the output of a jit-decorated function has no data dependence on its value. Pure callbacks may also be reordered if data dependence allows.
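The sketch below illustrates the last point, under the assumption that the callback's result is simply discarded. The list calls is a hypothetical counter used only to observe host invocations. Because the function's return value does not depend on the callback, JAX is free to skip it, so no particular value of len(calls) should be relied on.

```python
import jax
import jax.numpy as jnp
import numpy as np

calls = []  # hypothetical counter, for illustration only

def host_square(x):
    # This side effect is NOT guaranteed to run: pure_callback assumes purity.
    calls.append(1)
    x = np.asarray(x)
    return x * x

@jax.jit
def f(x):
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    _unused = jax.pure_callback(host_square, out_type, x)
    # The output has no data dependence on _unused, so the callback
    # may be eliminated; len(calls) may remain 0.
    return x + 1

y = f(jnp.arange(3.0, dtype=jnp.float32))
```

Code that must perform side effects should therefore not rely on a pure callback being executed.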

Under vmap(), the behavior depends on the value of the vectorized keyword argument. When vectorized is True, the callback is assumed to obey jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs]). The callback is therefore called directly on batched inputs (with the batch axes as the leading dimensions), and it should return outputs that have corresponding leading batch axes. When vectorized is False, the callback is mapped sequentially across the batch axis. For example, if callback = lambda x, y: np.matmul(x, y), then we are free to set vectorized=True because np.matmul handles arbitrary leading batch dimensions.
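A sketch of the np.matmul case, assuming a JAX version that accepts the vectorized argument shown in the signature above. The names host_matmul and matmul_via_callback are hypothetical. The result type is declared for a single (unbatched) example; with vectorized=True the callback receives the whole batch at once.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_matmul(x, y):
    # np.matmul broadcasts over leading batch dimensions, so receiving
    # a whole batch at once gives the same result as a per-example loop.
    return np.matmul(np.asarray(x), np.asarray(y))

def matmul_via_callback(x, y):
    # Inside vmap, x has shape (2, 3) and y has shape (3, 4): declare the
    # unbatched output shape (2, 4).
    out_type = jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), x.dtype)
    return jax.pure_callback(host_matmul, out_type, x, y, vectorized=True)

xs = jnp.ones((5, 2, 3), dtype=jnp.float32)
ys = jnp.ones((5, 3, 4), dtype=jnp.float32)
# The callback is invoked on inputs of shape (5, 2, 3) and (5, 3, 4).
out = jax.jit(jax.vmap(matmul_via_callback))(xs, ys)
```

Setting vectorized=True here avoids one host call per batch element; with vectorized=False the same code would still be correct but would map the callback sequentially over the five examples.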

Parameters:
  • callback (Callable[[...], Any]) – function to execute on the host. The callback is assumed to be a pure function (i.e. one without side-effects): if an impure function is passed, it may behave in unexpected ways, particularly under transformation. The callable will be passed PyTrees of arrays as arguments, and should return a PyTree of arrays that matches result_shape_dtypes.

  • result_shape_dtypes (Any) – pytree whose leaves have shape and dtype attributes and whose structure matches the expected output of the callback function at runtime. jax.ShapeDtypeStruct is often used to define leaf values.

  • *args (Any) – arguments to be passed to the callback function.

  • sharding (SingleDeviceSharding | None) – optional sharding that specifies the device from which the callback should be invoked.

  • vectorized (bool) – boolean specifying whether the callback function can operate in a vectorized manner.

  • **kwargs (Any) – keyword arguments to be passed to the callback function.
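The sketch below combines several of these parameters: a positional argument, a keyword argument forwarded through **kwargs, and a dict-shaped result_shape_dtypes. The names host_stats and stats are hypothetical, and the keyword argument is passed as a JAX array on the assumption that forwarded arguments should be array-valued.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_stats(x, *, scale):
    # x arrives positionally and scale as a keyword argument.
    x = np.asarray(x)
    s = np.asarray(scale)
    # The returned dict must have the same structure, shapes and dtypes
    # as result_shape_dtypes below.
    return {
        "scaled": (x * s).astype(x.dtype),
        "total": np.asarray(np.sum(x), dtype=x.dtype),
    }

@jax.jit
def stats(x, scale):
    out_type = {
        "scaled": jax.ShapeDtypeStruct(x.shape, x.dtype),
        "total": jax.ShapeDtypeStruct((), x.dtype),
    }
    return jax.pure_callback(host_stats, out_type, x, scale=scale)

res = stats(jnp.arange(3.0, dtype=jnp.float32), jnp.float32(2.0))
```

The returned value is a dict of jax.Array objects mirroring out_type, here with res["scaled"] holding the scaled input and res["total"] holding its sum.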

Returns:

a pytree of jax.Array objects whose structure matches that of result_shape_dtypes.

Return type:

result

See also