jax.core.ClosedJaxpr

class jax.core.ClosedJaxpr(jaxpr, consts)[source]
Parameters
__init__(jaxpr, consts)[source]
Parameters

Methods

__init__(jaxpr, consts)

param jaxpr

map_jaxpr(f)

pretty_print(*[, source_info, print_shapes])

Attributes

eqns

in_avals

literals

out_avals

jaxpr

consts