jax.extend module
Modules for JAX extensions.
The jax.extend package provides modules for access to JAX internal machinery. See JEP #15856.
API policy
Unlike the public API, this package offers no compatibility guarantee across releases. Breaking changes will be announced via the JAX project changelog.
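Because the package carries no compatibility guarantee, one cautious pattern is to confine imports from it to a single place and fail with a clear message if a module is missing. The following is a minimal sketch, not an official recommendation; the helper name `require_linear_util` is hypothetical and introduced only for illustration.

```python
# Guarded import: jax.extend may change between releases, so keep the
# dependency in one place instead of scattering it across a code base.
try:
    from jax.extend import linear_util as lu
except ImportError:  # JAX is not installed, or the module is unavailable
    lu = None


def require_linear_util():
    """Hypothetical helper: return the module or raise a descriptive error."""
    if lu is None:
        raise RuntimeError(
            "jax.extend.linear_util is unavailable in this environment; "
            "check the JAX project changelog for breaking changes."
        )
    return lu
```

Code that depends on jax.extend can then call `require_linear_util()` once at start-up, so that an incompatible release surfaces as one explicit error rather than as scattered import failures.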
jax.extend.linear_util
| Represents a function f to which transforms are to be applied. |
| Memoization decorator for functions taking a WrappedFun as first argument. |
| |
| Adds one more transformation to a WrappedFun. |
| Adds one more transformation with auxiliary output to a WrappedFun. |
| Wraps function f as a WrappedFun, suitable for transformation. |
jax.extend.random
| |
| |
| Apply the Threefry 2x32 hash. |
| Specifies PRNG key shape and operations. |
| Specifies PRNG key shape and operations. |
| Specifies PRNG key shape and operations. |
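As an illustration of the Threefry 2x32 hash entry, the sketch below hashes a small block of counters with a fixed key pair. The function name `threefry_2x32` and its `(keypair, count)` arguments do not appear in the table above; they reflect the JAX API as I understand it, and should be treated as an assumption to verify against the installed release. The alias `jex_random` and the key values are arbitrary choices for the example.

```python
import jax.numpy as jnp
from jax.extend import random as jex_random  # alias chosen for this example

# A key pair of two uint32 scalars (arbitrary illustrative values).
key = (jnp.uint32(0), jnp.uint32(42))

# Counters to hash; the input must be uint32.
count = jnp.arange(4, dtype=jnp.uint32)

# Assumed signature: threefry_2x32(keypair, count) -> array shaped like count.
bits = jex_random.threefry_2x32(key, count)

# The hash is a pure function: the same key and counters give the same bits.
bits_again = jex_random.threefry_2x32(key, count)
```

Since this package has no compatibility guarantee, code that calls this function directly may need updating when JAX is upgraded.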