jax.tree_util.keystr

jax.tree_util.keystr(keys)

Helper to pretty-print a tuple of keys, such as a key path produced by jax.tree_util.tree_flatten_with_path.

Parameters:

keys (tuple[KeyEntry, ...]) – A tuple of KeyEntry objects, or of any objects that can be converted to a string.

Returns:

A string formed by concatenating the string representations of the keys, with no separator between them.
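When the keys are KeyEntry objects, each entry's string form already carries its accessor syntax, so the concatenated result reads like an indexing expression. The following is a minimal sketch that builds a key path by hand, assuming a JAX version that exports DictKey, SequenceKey, and GetAttrKey from jax.tree_util; in practice such paths usually come from the *_with_path functions rather than being constructed manually.

```python
from jax.tree_util import DictKey, GetAttrKey, SequenceKey, keystr

# A hand-built key path: dict key "params", then list index 0,
# then attribute "weight".
path = (DictKey("params"), SequenceKey(0), GetAttrKey("weight"))

# DictKey renders as ['params'], SequenceKey as [0], GetAttrKey as .weight;
# keystr simply concatenates these string forms.
print(keystr(path))  # ['params'][0].weight
```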

Examples

>>> import jax
>>> keys = (0, 1, 'a', 'b')
>>> jax.tree_util.keystr(keys)
'01ab'
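A common use is labelling the leaves of a nested pytree. The sketch below uses a made-up nested dictionary (the names params, w, b, and step are illustrative only) and jax.tree_util.tree_flatten_with_path, which pairs every leaf with its key path; each path is then rendered with keystr. Dictionary entries are flattened in sorted key order, which is why b appears before w.

```python
import jax

# Hypothetical nested pytree of plain Python values.
tree = {"params": {"w": [1.0, 2.0], "b": 0.5}, "step": 3}

# Each element is a (key_path, leaf) pair; the treedef is not needed here.
leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)

# Turn every key path into a readable label.
labels = [jax.tree_util.keystr(path) for path, _ in leaves_with_paths]

for label, (_, leaf) in zip(labels, leaves_with_paths):
    print(label, leaf)
# ['params']['b'] 0.5
# ['params']['w'][0] 1.0
# ['params']['w'][1] 2.0
# ['step'] 3
```

The same pattern is handy for logging or debugging, for example when reporting which parameter in a model's pytree contains a NaN.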