jax.core.Jaxpr
- class jax.core.Jaxpr(constvars, invars, outvars, eqns, effects=frozenset({}), debug_info=None)
- Parameters:
constvars (Sequence[Var])
invars (Sequence[Var])
outvars (Sequence[Atom])
eqns (Sequence[JaxprEqn])
effects (Effects)
debug_info (JaxprDebugInfo | None)
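The fields above can be read directly off any traced function. This is a minimal sketch that assumes `jax.make_jaxpr` returns a `ClosedJaxpr` whose `.jaxpr` attribute is a `Jaxpr`. The function `f` is a hypothetical example, and `lax.sin`/`lax.mul` are used instead of `jnp` so the primitives are predictable.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x, y):
    # Hypothetical example function: two primitives, sin then mul.
    return lax.mul(lax.sin(x), y)

x = jnp.ones(3, dtype=jnp.float32)
y = jnp.full(3, 2.0, dtype=jnp.float32)

closed = jax.make_jaxpr(f)(x, y)  # ClosedJaxpr wrapping a Jaxpr
jaxpr = closed.jaxpr

print(jaxpr)                                  # textual form of the Jaxpr
print(len(jaxpr.constvars))                   # f closes over no arrays
print(len(jaxpr.invars), len(jaxpr.outvars))  # inputs and outputs of f
print([eqn.primitive.name for eqn in jaxpr.eqns])
print(jaxpr.effects)                          # f is pure, so no effects expected
```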
- __init__(constvars, invars, outvars, eqns, effects=frozenset({}), debug_info=None)
- Parameters:
constvars (Sequence[Var]) – list of variables introduced for constants. Array constants are replaced with such variables while scalar constants are kept inline.
invars (Sequence[Var]) – list of input variables. Together, constvars and invars are the inputs to the Jaxpr.
outvars (Sequence[Atom]) – list of output atoms.
eqns (Sequence[JaxprEqn]) – list of equations.
effects (Effects) – set of effects. The effects on a jaxpr are a superset of the union of the effects for each equation.
debug_info (JaxprDebugInfo | None) – optional JaxprDebugInfo.
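The constant handling and the effects rule described above can be observed by tracing. This sketch assumes default tracing settings in a recent JAX release. How closed-over arrays are represented has changed between versions, so it only checks that `ClosedJaxpr.consts` lines up with `constvars`, not a specific count. The names `c`, `g` and `h` are hypothetical, and literals are detected by class name to avoid depending on what `jax.core` exports in a given version.

```python
import jax
import jax.numpy as jnp

c = jnp.arange(3.0)  # hypothetical closed-over array constant

def g(x):
    # `c` is captured from the enclosing scope; 2.0 is a scalar constant.
    return (x + c) * 2.0

closed = jax.make_jaxpr(g)(jnp.ones(3))
jaxpr = closed.jaxpr
print(jaxpr)

# Array constants: ClosedJaxpr.consts supplies one value per constvar.
print(len(jaxpr.constvars), len(closed.consts))

# Scalar constants are kept inline as Literal atoms among equation inputs.
literals = [v for eqn in jaxpr.eqns for v in eqn.invars
            if type(v).__name__ == "Literal"]
print(literals)

def h(x):
    jax.debug.print("x = {}", x)  # stages a side-effecting callback
    return x + 1.0

effectful = jax.make_jaxpr(h)(1.0).jaxpr

# The Jaxpr's effects should cover the union of its equations' effects.
eqn_effects = set().union(*(eqn.effects for eqn in effectful.eqns))
print(effectful.effects)
print(eqn_effects <= set(effectful.effects))
```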
Methods
__init__(constvars, invars, outvars, eqns[, ...])
pretty_print(*[, source_info, print_shapes, ...])
replace(**kwargs)
Attributes
constvars
debug_info
effects
eqns
invars
outvars
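The two non-constructor methods can be exercised as follows. This is a sketch that assumes `pretty_print` returns the formatted text as a `str` and that `replace` builds a new `Jaxpr` from the given keyword overrides, keeping the remaining fields. The function `f` and the "identity" rewrite are hypothetical illustrations, and the output count is kept unchanged so the existing debug info still matches.

```python
import jax
import jax.numpy as jnp
from jax import lax

def f(x):
    # Hypothetical example: a single cos equation.
    return lax.cos(x)

jaxpr = jax.make_jaxpr(f)(jnp.ones(3, dtype=jnp.float32)).jaxpr

# print_shapes is one of the keyword-only options accepted by pretty_print.
text = jaxpr.pretty_print(print_shapes=False)
print(text)

# replace returns a new Jaxpr; here the equations are dropped and the
# input is returned directly, while the original Jaxpr is left untouched.
identity = jaxpr.replace(eqns=[], outvars=list(jaxpr.invars))
print(identity.pretty_print())
print(len(jaxpr.eqns), len(identity.eqns))
```

Because `replace` copies unspecified fields from the original, it is a convenient way to derive a modified Jaxpr without mutating the one produced by tracing.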