A context manager that adds a user-specified name to the JAX name stack.
When staging out computations for just-in-time compilation to XLA (or other backends such as TensorFlow), JAX does not, by default, preserve the names (or other source metadata) of the Python functions it encounters. This can make debugging the staged-out (and/or compiled) representation of your program harder, because there is little context information for each operation being executed.
named_scope tells JAX to stage the given function with additional annotations on the underlying operations. JAX internally keeps track of these annotations in a name stack. When the staged-out program is compiled with XLA, these annotations are preserved and show up in debugging utilities such as the TensorFlow Profiler in TensorBoard. Names are also preserved when staging out JAX programs to TensorFlow using
Parameters: name (str) – The prefix to use to name all operations created within the name scope.
Yields: None, but enters a context in which name will be appended to the active name stack.
named_scope can be used as a context manager inside compiled functions:
>>> import jax
>>>
>>> @jax.jit
... def layer(w, x):
...   with jax.named_scope("dot_product"):
...     logits = w.dot(x)
...   with jax.named_scope("activation"):
...     return jax.nn.relu(logits)
It can also be used as a decorator:
>>> @jax.jit
... @jax.named_scope("layer")
... def layer(w, x):
...   logits = w.dot(x)
...   return jax.nn.relu(logits)