jax.core.Jaxpr

class jax.core.Jaxpr(constvars, invars, outvars, eqns, effects=frozenset({}), debug_info=None)

__init__(constvars, invars, outvars, eqns, effects=frozenset({}), debug_info=None)
Parameters:
  • constvars (Sequence[Var]) – list of variables introduced for constants. Array constants are replaced with such variables while scalar constants are kept inline.

  • invars (Sequence[Var]) – list of input variables. Together, constvars and invars are the inputs to the Jaxpr.

  • outvars (Sequence[Atom]) – list of output atoms.

  • eqns (Sequence[JaxprEqn]) – list of equations.

  • effects (Effects) – set of effects. The effects on a jaxpr are a superset of the union of the effects for each equation.

  • debug_info (JaxprDebugInfo | None) – optional JaxprDebugInfo.
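
The distinction between constvars and invars is easier to see on a concrete jaxpr. The sketch below does not call the constructor directly; it assumes the usual route of obtaining a Jaxpr through jax.make_jaxpr, which returns a ClosedJaxpr whose .jaxpr attribute is the Jaxpr described here. The function f and the array table are hypothetical names chosen for illustration, and how a captured array is represented may differ between JAX versions.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical array constant, captured by closure rather than passed in.
table = np.arange(3, dtype=np.float32)

def f(x):
    # `table` is an array constant; `2.0` is a scalar constant.
    return x * table + 2.0

# make_jaxpr returns a ClosedJaxpr: a Jaxpr paired with its constant values.
closed = jax.make_jaxpr(f)(jnp.ones(3, dtype=jnp.float32))
jaxpr = closed.jaxpr

print(jaxpr)                           # textual form of the whole jaxpr
print("constvars:", jaxpr.constvars)   # variables standing in for array constants
print("invars:   ", jaxpr.invars)      # variables for the explicit argument x
print("outvars:  ", jaxpr.outvars)     # output atoms
print("effects:  ", jaxpr.effects)     # expected to be empty for a pure function
```

Each entry of constvars corresponds to one value in closed.consts, so the two sequences have the same length. The scalar 2.0 is expected to appear inline in the printed equations instead of being given a variable.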

Methods

__init__(constvars, invars, outvars, eqns[, ...])

pretty_print(*[, source_info, print_shapes, ...])

replace(*[, constvars, invars, outvars, ...])
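
As a rough illustration of the two non-constructor methods, the sketch below renders a jaxpr with pretty_print and builds a modified copy with replace. It assumes that pretty_print returns a string and that replace returns a new Jaxpr while leaving the original untouched; the function g is a hypothetical example, and the jaxpr is again obtained through jax.make_jaxpr.

```python
import jax

def g(x):
    # Two outputs, so that there is something to reorder below.
    return x + 1.0, x * 2.0

jaxpr = jax.make_jaxpr(g)(3.0).jaxpr

# pretty_print takes keyword-only options; print_shapes=False omits
# the shape/dtype annotations from the rendered text.
text = jaxpr.pretty_print(print_shapes=False)
print(text)

# replace builds a new Jaxpr with the named fields swapped out.
# Here the two outputs are returned in the opposite order.
swapped = jaxpr.replace(outvars=list(reversed(jaxpr.outvars)))
print(swapped)
```

The number of outputs is kept the same on purpose: replace only substitutes fields, so keeping the new jaxpr consistent (for example, with any attached debug_info) is left to the caller.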

Attributes

constvars

debug_info

effects

eqns

invars

outvars
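
The attributes above are enough to walk a jaxpr by hand. The following is a minimal sketch, assuming each JaxprEqn exposes primitive, invars, outvars and params; the function h is hypothetical and uses jax.lax primitives directly so that each call maps to one equation.

```python
import jax
from jax import lax

def h(x, y):
    # Each lax call binds exactly one primitive.
    return lax.mul(lax.sin(x), y)

jaxpr = jax.make_jaxpr(h)(1.0, 2.0).jaxpr

# Collect the primitive name of every equation, in program order.
names = [eqn.primitive.name for eqn in jaxpr.eqns]

for eqn in jaxpr.eqns:
    # invars/outvars of an equation are the atoms it reads and the variables it defines.
    print(eqn.primitive.name, "reads", eqn.invars, "defines", eqn.outvars, "params", eqn.params)

print("inputs: ", jaxpr.invars)
print("outputs:", jaxpr.outvars)
```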