jax.profiler.trace_function

jax.profiler.trace_function(func, name=None, **kwargs)

Decorator that generates a trace event for the execution of a function.

For example:

>>> import jax, jax.numpy as jnp
>>>
>>> @jax.profiler.trace_function
... def f(x):
...   return jnp.dot(x, x.T).block_until_ready()
>>>
>>> f(jnp.ones((1000, 1000)))

This will cause an “f” event to show up on the trace timeline if the function execution occurs while the process is being traced by TensorBoard.

Arguments can be passed to the decorator via functools.partial().

>>> import jax, jax.numpy as jnp
>>> from functools import partial
>>>
>>> @partial(jax.profiler.trace_function, name="event_name")
... def f(x):
...   return jnp.dot(x, x.T).block_until_ready()
>>>
>>> f(jnp.ones((1000, 1000)))
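
The functools.partial() pattern works because the decorator takes the function as its first positional argument and everything else as keywords. The following is a minimal pure-Python sketch of that mechanism, not JAX's implementation: trace_function_sketch and the EVENTS list are hypothetical stand-ins that record the event name and duration in memory instead of emitting a profiler event, and the default of falling back to the function's name is an assumption based on the “f” event described above.

```python
import functools
import time

# Hypothetical in-memory event log standing in for a real trace timeline.
EVENTS = []


def trace_function_sketch(func, name=None, **kwargs):
    """Illustrative stand-in with the same (func, name=None, **kwargs) shape."""
    # Assumption: the event is named after the function unless name is given.
    event_name = name if name is not None else func.__name__

    @functools.wraps(func)
    def wrapper(*args, **call_kwargs):
        start = time.perf_counter()
        try:
            return func(*args, **call_kwargs)
        finally:
            # Record the event even if func raises.
            EVENTS.append((event_name, time.perf_counter() - start, kwargs))

    return wrapper


# Bare decorator: func is passed positionally, name stays None.
@trace_function_sketch
def f(x):
    return x * 2


# partial pre-binds name, leaving func as the only remaining argument.
@functools.partial(trace_function_sketch, name="event_name")
def g(x):
    return x + 1


f(3)
g(3)
print([event[0] for event in EVENTS])  # prints ['f', 'event_name']
```

Because partial() returns a callable that still expects func, it can be used directly after the @ sign without writing a separate decorator factory.
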
Parameters