jax.extend.core module

ClosedJaxpr(jaxpr, consts)
Jaxpr(constvars, invars, outvars, eqns[, ...])
JaxprEqn(invars, outvars, primitive, params, ...)
Literal(val, aval)
Primitive(name)
Token(buf)
Var(suffix, aval)
array_types: a set object; its listed summary is the built-in set docstring ("set() -> new empty set object", "set(iterable) -> new set object").
jaxpr_as_fun
primitives
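The following is a minimal sketch of how several of these names fit together. It assumes a recent JAX release in which `jax.make_jaxpr` returns a `ClosedJaxpr` and `jaxpr_as_fun(closed_jaxpr)` returns a callable, as in current JAX. The function `f` and the primitive named `"double"` are hypothetical illustrations, not part of the module. The custom primitive defines only an eager implementation and an abstract-evaluation rule. That is enough for eager `bind` and for `jax.make_jaxpr`, but it would also need a lowering rule before it could be used under `jax.jit`.

```python
import jax
import jax.numpy as jnp
from jax.extend.core import (
    ClosedJaxpr, Jaxpr, JaxprEqn, Literal, Primitive, Var, jaxpr_as_fun,
)


def f(x):
    # Hypothetical example function: a sine followed by a scaling.
    return jnp.sin(x) * 2.0


x = jnp.arange(3.0)

# Tracing with jax.make_jaxpr yields a ClosedJaxpr: a Jaxpr plus its consts.
closed = jax.make_jaxpr(f)(x)
assert isinstance(closed, ClosedJaxpr)
assert isinstance(closed.jaxpr, Jaxpr)

# Each equation is a JaxprEqn; its inputs are Vars or inlined Literals.
for eqn in closed.jaxpr.eqns:
    assert isinstance(eqn, JaxprEqn)
    assert all(isinstance(v, (Var, Literal)) for v in eqn.invars)
    print(eqn.primitive.name, [type(v).__name__ for v in eqn.invars])

prim_names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]

# jaxpr_as_fun turns the ClosedJaxpr back into a callable that
# returns a sequence with one output per outvar.
outs = jaxpr_as_fun(closed)(x)
assert jnp.allclose(outs[0], f(x))

# Hypothetical custom primitive: eager impl plus an abstract-eval
# (shape/dtype) rule only. No lowering rule, so not usable under jax.jit.
double_p = Primitive("double")
double_p.def_impl(lambda a: a * 2)
double_p.def_abstract_eval(lambda aval: aval)  # output aval = input aval

eager = double_p.bind(x)  # eager bind calls the impl directly
closed2 = jax.make_jaxpr(lambda a: double_p.bind(a))(x)  # uses abstract eval
traced_out = jaxpr_as_fun(closed2)(x)[0]  # evaluating the jaxpr calls the impl
```

Comparing `eqn.primitive` against the `Primitive` object by identity (rather than by name) is one way to locate a specific primitive when walking `closed.jaxpr.eqns`.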