- jax.pure_callback(callback, result_shape_dtypes, *args, vectorized=False, **kwargs)
pure_callback enables calling a Python function in JIT-ed JAX functions. The input callback will be passed NumPy arrays in place of JAX arrays and should also return NumPy arrays. Execution takes place on CPU, like any Python+NumPy function.
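As an illustration of the description above, here is a minimal sketch of calling a NumPy function from inside a jit-compiled function. The names host_sin and f are hypothetical, and jax.ShapeDtypeStruct is used here as one convenient way to describe the expected result; any object with shape and dtype attributes should serve the same purpose.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sin(x):
    # Hypothetical host-side function: runs as ordinary Python + NumPy on CPU.
    x = np.asarray(x)
    return np.sin(x).astype(x.dtype)


@jax.jit
def f(x):
    # Describe the shape and dtype of the value the callback will return.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, result_shape, x)


x = jnp.arange(4, dtype=jnp.float32)
y = f(x)
```

The callback's return value has to match the declared shape and dtype, which is why the sketch casts the result back to the input dtype explicitly.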
The callback is treated as functionally pure, meaning it has no side effects and its output value depends only on its argument values. As a consequence, it is safe for it to be called multiple times (e.g. when transformed by vmap() or pmap()), or not to be called at all, e.g. when the output of a jit-decorated function has no data dependence on its value. Pure callbacks may also be reordered if data dependence allows.
When pmap()-ed, the pure callback will be called several times (once for each element of the mapped axis). When vmap-ed, the behavior depends on the value of the vectorized keyword argument. When vectorized is True, the callback is assumed to obey jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs]). The callback is therefore called directly on batched inputs (where the batch axes are the leading dimensions), and it should return outputs that have corresponding leading batch axes. If not vectorized, callback is mapped sequentially across the batched axis. For example, if callback = lambda x, y: np.matmul(x, y), then we are free to set vectorized=True because the np.matmul function handles arbitrary leading batch dimensions.
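The batching contract above can be checked in plain NumPy before deciding whether vectorized=True is appropriate for a callback. The following is a sketch under assumed, arbitrary array shapes; it compares one call on the whole batch against a per-element loop, and uses np.dot as a counterexample that does not satisfy the contract.

```python
import numpy as np


def callback(x, y):
    # Same callback as in the text: np.matmul broadcasts over leading batch dims.
    return np.matmul(x, y)


rng = np.random.default_rng(0)
xs = rng.standard_normal((5, 2, 3))  # batch of 5 matrices of shape (2, 3)
ys = rng.standard_normal((5, 3, 4))  # batch of 5 matrices of shape (3, 4)

# One call on the whole batch: what a vectorized callback would receive.
batched = callback(xs, ys)

# One call per batch element, then stacked: what sequential mapping produces.
looped = np.stack([callback(x, y) for x, y in zip(xs, ys)])

contract_holds = batched.shape == looped.shape and np.allclose(batched, looped)

# Counterexample: np.dot does not treat the leading axis as a shared batch
# axis, so its batched result has a different shape from the looped one.
dot_batched_shape = np.dot(xs, ys).shape
dot_contract_holds = dot_batched_shape == looped.shape
```

A callback for which this check fails should keep the default vectorized=False, so that it is applied to each mapped input individually.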
result_shape_dtypes (Any) – A PyTree whose leaves are objects with shape and dtype attributes, which represent the shapes and dtypes of the value of callback applied to args and kwargs.
*args – The positional arguments to the callback. Must be PyTrees of JAX types.
vectorized (bool) – A boolean that indicates whether callback is vectorized, meaning it can handle arrays with additional leading dimensions. If vectorized is True, when the callback is mapped via jax.vmap, it will be called directly on inputs with leading batch dimensions instead of executing callback on each mapped input individually. The callback should also return outputs batched across the leading axis. By default, vectorized is False.
**kwargs – The keyword arguments to the callback. Must be PyTrees of JAX types.
Returns the value of callback(*args, **kwargs).
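Because result_shape_dtypes is a PyTree, the callback can return a structured value. The sketch below assumes a hypothetical callback, host_min_max, that returns a dict of two arrays, with a matching dict of jax.ShapeDtypeStruct objects describing the result.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_min_max(x):
    # Hypothetical host-side function returning a PyTree (a dict) of NumPy arrays.
    x = np.asarray(x)
    return {"min": x.min(axis=-1), "max": x.max(axis=-1)}


@jax.jit
def min_max(x):
    # Each leaf of the result spec mirrors one leaf of the callback's output.
    leaf = jax.ShapeDtypeStruct(x.shape[:-1], x.dtype)
    spec = {"min": leaf, "max": leaf}
    return jax.pure_callback(host_min_max, spec, x)


x = jnp.array([[1.0, 5.0, 3.0], [4.0, 2.0, 6.0]], dtype=jnp.float32)
stats = min_max(x)
```

The value returned by min_max has the same dict structure as the spec, with JAX arrays as leaves.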