jax.extend module

Modules for JAX extensions.

The jax.extend package provides modules for access to JAX internal machinery. See JEP #15856.

API policy

Unlike the public API, this package offers no compatibility guarantee across releases. Breaking changes will be announced via the JAX project changelog.

jax.extend.linear_util

StoreException

WrappedFun(f, transforms, stores, params, ...)

Represents a function f to which transforms are to be applied.

cache(call)

Memoization decorator for functions taking a WrappedFun as first argument.

merge_linear_aux(aux1, aux2)

transformation(gen, fun, *gen_static_args)

Adds one more transformation to a WrappedFun.

transformation_with_aux(gen, fun, ...[, ...])

Adds one more transformation with auxiliary output to a WrappedFun.

wrap_init(f[, params])

Wraps function f as a WrappedFun, suitable for transformation.
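The entries above describe a wrap-then-transform pattern: wrap_init turns a plain function into a WrappedFun, transformation stacks further transformations onto it, and nothing runs until the wrapped function is finally called. The sketch below is a toy, pure-Python model of that pattern, not the jax.extend.linear_util implementation. All names in it (ToyWrappedFun, toy_wrap_init, toy_transformation, toy_cache, scale_output, add_to_inputs) are hypothetical, and the generator protocol it uses (a transformation yields the arguments for the inner call, receives the inner result, then yields its own result) is an assumption modeled on the `transformation(gen, fun, *gen_static_args)` signature. Because this package has no compatibility guarantee, check the source of your installed JAX release before relying on the real API's details.

```python
class ToyWrappedFun:
    """Toy stand-in for WrappedFun: a base function plus pending transforms."""

    def __init__(self, f, transforms=(), params=()):
        self.f = f
        self.transforms = transforms  # (gen, static_args) pairs, outermost first
        self.params = params          # sorted tuple of (name, value) pairs

    def wrap(self, gen, *static_args):
        # Returns a new object: the newest transformation becomes the outermost.
        return ToyWrappedFun(self.f, ((gen, static_args),) + self.transforms, self.params)

    def call_wrapped(self, *args):
        stack = []
        # Outermost transformation sees the caller's arguments first.
        for gen, static_args in self.transforms:
            g = gen(*static_args, *args)
            args = next(g)  # the transformation chooses the inner call's args
            stack.append(g)
        ans = self.f(*args, **dict(self.params))
        # Unwind innermost first: each generator receives the inner result.
        for g in reversed(stack):
            ans = g.send(ans)
        return ans


def toy_wrap_init(f, params=None):
    # Keyword params are stored as a sorted, hashable tuple.
    params = () if params is None else tuple(sorted(params.items()))
    return ToyWrappedFun(f, (), params)


def toy_transformation(gen, fun, *static_args):
    return fun.wrap(gen, *static_args)


def toy_cache(call):
    # Memoizes on the pieces that identify a wrapped function, in the spirit
    # of a WrappedFun-aware cache; extra arguments must be hashable.
    memo = {}

    def memoized(fun, *args):
        key = (fun.f, fun.transforms, fun.params, args)
        if key not in memo:
            memo[key] = call(fun, *args)
        return memo[key]

    return memoized


def scale_output(factor, *args):
    out = yield args   # run the inner function on unchanged arguments
    yield out * factor  # then scale its result


def add_to_inputs(delta, *args):
    out = yield tuple(a + delta for a in args)  # shift every input
    yield out


def f(x, y, *, power):
    return (x + y) ** power


fun = toy_wrap_init(f, {"power": 2})
fun = toy_transformation(add_to_inputs, fun, 1)  # inner transformation
fun = toy_transformation(scale_output, fun, 10)  # outer transformation
result = fun.call_wrapped(1, 2)  # f(2, 3, power=2) * 10

calls = []


@toy_cache
def evaluate(fun, tag):
    calls.append(tag)
    return fun.call_wrapped(1, 2)


evaluate(fun, "a")
evaluate(fun, "a")  # same function, transforms, params, and args: memo hit
print(result, calls)
```

Storing the transforms and params as tuples keeps the wrapped function hashable, which is what lets a cache decorator recognize a repeated request without re-running the transformed function.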

jax.extend.random

define_prng_impl(*, key_shape, seed, split, ...)

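Defines a new PRNG implementation from a key shape and the functions that seed, split, and otherwise operate on keys of that shape.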
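A minimal sketch of define_prng_impl, assuming it also accepts random_bits, fold_in, name, and tag keywords (the signature above is elided) and that its return value can be passed as the impl argument of jax.random.key. To keep the example correct, it reuses the functions of threefry_prng_impl, assuming that object exposes them as the attributes key_shape, seed, split, random_bits, and fold_in; the name and tag strings are hypothetical. A real extension would substitute its own functions with the same calling conventions.

```python
import jax
import jax.numpy as jnp
from jax.extend.random import define_prng_impl, threefry_prng_impl

base = threefry_prng_impl

# Assumed keyword arguments; each function is borrowed from the Threefry impl.
my_impl = define_prng_impl(
    key_shape=base.key_shape,
    seed=base.seed,
    split=base.split,
    random_bits=base.random_bits,
    fold_in=base.fold_in,
    name="my_threefry",  # hypothetical name
    tag="myfry",         # hypothetical tag
)

key = jax.random.key(0, impl=my_impl)
ref = jax.random.key(0, impl="threefry2x32")

# With identical functions, the raw key data and the draws should agree.
same_data = bool(jnp.array_equal(jax.random.key_data(key), jax.random.key_data(ref)))
same_draws = bool(jnp.array_equal(jax.random.uniform(key, (4,)),
                                  jax.random.uniform(ref, (4,))))
print(same_data, same_draws)
```

Delegating to an existing implementation is only a scaffold for checking the wiring; the point of define_prng_impl is to swap in different seed, split, and bit-generation functions.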

seed_with_impl(impl, seed)

Creates a PRNG key array from the integer seed, using the PRNG implementation impl.
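A minimal sketch of seeding with an explicit implementation, assuming seed_with_impl returns a typed key array of the kind jax.random.key produces, so that jax.random.key_data exposes its raw words. Here the Threefry implementation is compared against the key built from the registered name "threefry2x32".

```python
import jax
import jax.numpy as jnp
from jax.extend.random import seed_with_impl, threefry_prng_impl

key = seed_with_impl(threefry_prng_impl, 0)

# key_data strips the key type and returns the underlying uint32 words.
raw = jax.random.key_data(key)

# The same implementation and seed via the public API should give the same data.
ref = jax.random.key(0, impl="threefry2x32")
matches = bool(jnp.array_equal(raw, jax.random.key_data(ref)))

# The typed key works with the usual jax.random samplers.
sample = jax.random.uniform(key, (3,))
print(raw.shape, matches)
```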

threefry2x32_p

threefry_2x32(keypair, count)

Apply the Threefry 2x32 hash.
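A minimal sketch of calling the hash directly, assuming keypair is a pair of uint32 scalars and count is a uint32 array, with the result being a uint32 array shaped like count. The key and count values are arbitrary illustrations.

```python
import jax.numpy as jnp
from jax.extend.random import threefry_2x32

# Two 32-bit key words and a vector of counters to hash.
keypair = (jnp.uint32(0), jnp.uint32(42))
counts = jnp.arange(6, dtype=jnp.uint32)

bits = threefry_2x32(keypair, counts)

# The hash is deterministic: same key and counts give the same bits.
repeat = threefry_2x32(keypair, counts)
print(bits.shape, bits.dtype)  # (6,) uint32
```

Because the output is a pure function of the key and the counters, this counter-based design is what lets split and random_bits derive many independent streams from one key without shared state.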

threefry_prng_impl

Specifies the key shape and operations of the Threefry 2x32 PRNG implementation.

rbg_prng_impl

Specifies the key shape and operations of the RBG PRNG implementation, based on XLA's RngBitGenerator.

unsafe_rbg_prng_impl

Specifies the key shape and operations of the unsafe_rbg PRNG implementation.
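The three implementation objects above differ mainly in their key shape and in the functions behind them. A minimal sketch, assuming each object exposes a key_shape attribute and can be passed to seed_with_impl, seeds a key with every implementation and checks that the raw key data has the advertised shape.

```python
import jax
from jax.extend.random import (
    rbg_prng_impl,
    seed_with_impl,
    threefry_prng_impl,
    unsafe_rbg_prng_impl,
)

results = []
for impl in (threefry_prng_impl, rbg_prng_impl, unsafe_rbg_prng_impl):
    key = seed_with_impl(impl, 0)
    raw = jax.random.key_data(key)           # raw uint32 words of the key
    draws = jax.random.normal(key, (2,))     # any sampler accepts the typed key
    results.append((impl.key_shape, raw.shape, draws.shape))

for key_shape, raw_shape, draw_shape in results:
    print(key_shape, raw_shape, draw_shape)
```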