jax.pure_callback
jax.pure_callback(callback, result_shape_dtypes, *args, vectorized=False, **kwargs)
Applies a functionally pure Python callable. Works under jit(), pmap(), etc.

pure_callback enables calling a Python function inside JIT-ed JAX functions. The input callback will be passed NumPy arrays in place of JAX arrays and should also return NumPy arrays. Execution takes place on CPU, like any Python+NumPy function.

The callback is treated as functionally pure, meaning it has no side effects and its output value depends only on its argument values. As a consequence, it is safe for it to be called multiple times (e.g. when transformed by vmap() or pmap()), or not to be called at all when, e.g., the output of a jit-decorated function has no data dependence on its value. Pure callbacks may also be reordered if data dependence allows.

When pmap()-ed, the pure callback will be called several times (once for each index of the mapped axis). When vmap()-ed, the behavior depends on the value of the vectorized keyword argument. When vectorized is True, the callback is assumed to obey jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs]). Therefore, the callback will be called directly on batched inputs (where the batch axes are the leading dimensions), and it should return outputs that have corresponding leading batch axes. If vectorized is False, callback will be mapped sequentially across the batched axis. For example, if callback = lambda x, y: np.matmul(x, y), then we are free to set vectorized=True because the np.matmul function handles arbitrary leading batch dimensions.

Parameters
- callback (Callable[..., Any]) – A Python callable. The callable will be passed PyTrees of NumPy arrays as arguments, and should return a PyTree of NumPy arrays that matches result_shape_dtypes.
- result_shape_dtypes (Any) – A PyTree whose leaves are objects with shape and dtype attributes, representing the shapes and dtypes of the value of callback applied to args and kwargs.
- *args (Any) – The positional arguments to the callback. Must be PyTrees of JAX types.
- vectorized (bool) – Whether callback is vectorized, meaning it can handle arrays with additional leading dimensions. If vectorized is True, when the callback is mapped via jax.vmap, it will be called directly on inputs with leading batch dimensions instead of executing callback on each mapped input individually. The callback should then also return outputs batched across the leading axis. Defaults to False.
- **kwargs (Any) – The keyword arguments to the callback. Must be PyTrees of JAX types.
Returns
The value of callback(*args, **kwargs).
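A minimal usage sketch, assuming a JAX version whose pure_callback accepts the arguments documented above. The host function host_sin is a hypothetical name: it receives a NumPy array and runs ordinary NumPy code, while jax.ShapeDtypeStruct describes the expected output so the call can be traced under jit().

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_sin(x):
    # Hypothetical host function: x arrives as a NumPy array and
    # ordinary NumPy code runs on the host.
    return np.sin(x)


@jax.jit
def f(x):
    # Describe the shape and dtype of the callback's output so the
    # call can be traced without executing host_sin.
    out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, out_type, x)


x = jnp.arange(4.0)
y = f(x)
print(y)
```

The returned array should match result_shape_dtypes exactly; np.sin preserves the float32 input dtype, so no cast is needed in this sketch.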
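A second sketch, under the same assumption about the signature, exercises the PyTree and **kwargs behavior described under Parameters. The hypothetical host_stats takes a keyword argument scale and returns a dict, so result_shape_dtypes is a dict with the same keys whose leaves are jax.ShapeDtypeStruct objects.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_stats(x, *, scale):
    # Hypothetical host function: x and scale arrive as NumPy arrays.
    y = x * scale
    # Return a dict (a PyTree); cast leaves so their dtype matches x.
    return {
        "mean": np.asarray(np.mean(y), dtype=x.dtype),
        "max": np.asarray(np.max(y), dtype=x.dtype),
    }


@jax.jit
def stats(x, scale):
    # result_shape_dtypes mirrors the structure of host_stats's return value.
    scalar = jax.ShapeDtypeStruct((), x.dtype)
    out_type = {"mean": scalar, "max": scalar}
    # scale is forwarded to the callback as a keyword argument.
    return jax.pure_callback(host_stats, out_type, x, scale=scale)


result = stats(jnp.array([1.0, 2.0, 3.0]), jnp.float32(2.0))
print(result["mean"], result["max"])
```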
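Before passing vectorized=True, one way to sanity-check the contract jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs]) is to compare a single batched call against a per-example loop in plain NumPy. The helper satisfies_vectorized_contract below is hypothetical, not part of JAX, and a passing check on one input is evidence rather than proof that the callback is safe to mark as vectorized.

```python
import numpy as np


def satisfies_vectorized_contract(callback, xs, ys, atol=1e-6):
    """Hypothetical helper: True if callback(xs, ys) equals stacking
    callback(x, y) over the leading (batch) axis."""
    batched = np.asarray(callback(xs, ys))
    looped = np.stack([callback(x, y) for x, y in zip(xs, ys)])
    # Compare shapes first so np.allclose never broadcasts mismatched outputs.
    return batched.shape == looped.shape and np.allclose(batched, looped, atol=atol)


rng = np.random.default_rng(0)
xs = rng.normal(size=(5, 3, 4))  # batch of five (3, 4) matrices
ys = rng.normal(size=(5, 4, 2))  # batch of five (4, 2) matrices

# np.matmul treats leading dimensions as batch dimensions.
matmul = lambda x, y: np.matmul(x, y)
# np.dot on 3-D inputs contracts across the batch axes instead,
# giving a (5, 3, 5, 2) result rather than (5, 3, 2).
dot = lambda x, y: np.dot(x, y)

print(satisfies_vectorized_contract(matmul, xs, ys))  # True
print(satisfies_vectorized_contract(dot, xs, ys))  # False
```

Under these assumptions, a callback built on np.matmul is a candidate for vectorized=True, while one built on np.dot should keep the default vectorized=False.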