jax.profiler module

class jax.profiler.TraceContext

Bases: jaxlib.xla_extension.profiler.TraceMe

Context manager that generates a trace event in the profiler.

The trace event spans the duration of the code enclosed by the context.

For example:

>>> import jax, jax.numpy as jnp
>>> x = jnp.ones((1000, 1000))
>>> with jax.profiler.TraceContext("acontext"):
...   jnp.dot(x, x.T).block_until_ready()

This will cause an “acontext” event to show up on the trace timeline if the event occurs while the process is being traced by TensorBoard.
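Trace contexts can also be nested to break a larger computation into labelled phases. The sketch below is illustrative only: train_step and the event names are hypothetical. It also assumes that newer JAX releases expose this class as jax.profiler.TraceAnnotation and falls back to that name if TraceContext is missing.

```python
import jax
import jax.numpy as jnp

# Documented name is TraceContext; assumed newer name is TraceAnnotation.
TraceContext = getattr(jax.profiler, "TraceContext", None) or jax.profiler.TraceAnnotation


def train_step(x):
    # Hypothetical step function. The outer event spans the whole step.
    with TraceContext("train_step"):
        # Inner events on the same thread are recorded inside "train_step".
        with TraceContext("forward"):
            y = jnp.tanh(x @ x.T)
        with TraceContext("reduce"):
            # block_until_ready() keeps the event open until the device work
            # finishes, rather than closing it right after async dispatch.
            loss = jnp.sum(y).block_until_ready()
    return loss


loss = train_step(jnp.ones((4, 4)))
```

When no trace is being collected, the context managers do nothing visible, so the function still returns its normal result.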

jax.profiler.start_server(port)

Starts a profiler server listening on the given port.

Using the “TensorFlow profiler” feature in TensorBoard 2.2 or newer, you can connect to the profiler server and sample execution traces that show CPU and GPU device activity.

Returns a profiler server object. The server remains alive and listening until the server object is destroyed.
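The sketch below shows the typical pattern: start the server once, keep a reference to the returned object so it is not destroyed, and then run the workload you want to sample from TensorBoard. The port number 9999 is an arbitrary, hypothetical choice; any free port should work.

```python
import jax
import jax.numpy as jnp

PORT = 9999  # Hypothetical port; choose any free one.

# Hold on to the returned object: the server stops listening once it is
# destroyed (for example, when it is garbage-collected).
server = jax.profiler.start_server(PORT)

# Some work to trace while TensorBoard is connected to the server.
x = jnp.ones((16, 16))
for _ in range(5):
    y = jnp.dot(x, x.T).block_until_ready()
```

In a long-running job, you would normally start the server at the top of the script and keep it for the whole process lifetime.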

jax.profiler.trace_function(func, name=None, **kwargs)

Decorator that generates a trace event for the execution of a function.

For example:

>>> import jax, jax.numpy as jnp
>>>
>>> @jax.profiler.trace_function
... def f(x):
...   return jnp.dot(x, x.T).block_until_ready()
>>>
>>> f(jnp.ones((1000, 1000)))

This will cause an “f” event to show up on the trace timeline if the function execution occurs while the process is being traced by TensorBoard.
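Because the decorator wraps every call, a decorated function invoked in a loop should yield one event per call on the timeline. The following is a minimal sketch with a hypothetical function matmul_step. It assumes newer JAX releases provide the same decorator as jax.profiler.annotate_function and falls back to that name if trace_function is unavailable.

```python
import jax
import jax.numpy as jnp

# Documented name is trace_function; assumed newer name is annotate_function.
trace_function = getattr(jax.profiler, "trace_function", None) or jax.profiler.annotate_function


@trace_function
def matmul_step(x):
    # Hypothetical workload; each call is wrapped in its own trace event.
    return jnp.dot(x, x.T).block_until_ready()


x = jnp.ones((8, 8))
results = [matmul_step(x) for _ in range(3)]  # three calls, three events
```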

Keyword arguments such as name can be passed to the decorator via functools.partial().

>>> import jax, jax.numpy as jnp
>>> from functools import partial
>>>
>>> @partial(jax.profiler.trace_function, name="event_name")
... def f(x):
...   return jnp.dot(x, x.T).block_until_ready()
>>>
>>> f(jnp.ones((1000, 1000)))
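functools.partial only pre-binds the keyword argument, so the decorator form with a custom name is equivalent to calling trace_function directly with name=. The sketch below wraps one hypothetical function, matmul, twice under different event names. As above, it assumes the newer name jax.profiler.annotate_function as a fallback.

```python
import functools

import jax
import jax.numpy as jnp

# Documented name is trace_function; assumed newer name is annotate_function.
trace_function = getattr(jax.profiler, "trace_function", None) or jax.profiler.annotate_function


def matmul(x):
    # Hypothetical workload shared by both wrappers.
    return jnp.dot(x, x.T).block_until_ready()


# Decorator form via functools.partial, with a custom event name.
named_a = functools.partial(trace_function, name="matmul_a")(matmul)
# Equivalent direct call with the same keyword argument.
named_b = trace_function(matmul, name="matmul_b")

x = jnp.ones((8, 8))
a = named_a(x)  # recorded as "matmul_a" when tracing
b = named_b(x)  # recorded as "matmul_b" when tracing
```

Giving distinct names this way can help tell apart calls of the same function made from different parts of a program.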