Public API: jax package#
Subpackages#
jax.numpy module
jax.scipy module
jax.lax module
jax.random module
jax.sharding module
jax.debug module
jax.dlpack module
jax.distributed module
jax.dtypes module
jax.flatten_util module
jax.image module
jax.nn module
jax.ops module
jax.profiler module
jax.stages module
jax.tree_util module
jax.typing module
jax.extend module
jax.example_libraries module
jax.experimental module
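Several of the subpackages listed above are typically used together. The following is a minimal sketch, assuming a recent JAX release; the `params` dictionary is purely illustrative. It touches `jax.random` (explicit PRNG keys), `jax.tree_util` (mapping over nested containers) and `jax.nn` (activations):

```python
import jax
import jax.numpy as jnp

# jax.random: explicit, splittable PRNG keys instead of a global RNG state.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)

# A small, illustrative nested container of arrays (a "pytree").
params = {"w": jax.random.normal(subkey, (3, 2)), "b": jnp.zeros(2)}

# jax.tree_util: apply a function to every leaf of a nested container.
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, params)

# jax.nn: common neural-network building blocks such as activations.
h = jax.nn.relu(jnp.array([-1.0, 0.0, 2.0]))

print(shapes)
print(h)
```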
Configuration#
Context manager for jax_check_tracer_leaks config option.
Context manager for jax_check_tracer_leaks config option.
Context manager for jax_debug_nans config option.
Context manager for jax_debug_infs config option.
Context manager for jax_default_device config option.
Context manager for jax_default_matmul_precision config option.
Context manager for jax_default_prng_impl config option.
Context manager for jax_enable_checks config option.
Context manager for jax_enable_custom_prng config option (transient).
Context manager for jax_enable_custom_vjp_by_custom_transpose config option (transient).
Context manager for jax_log_compiles config option.
Context manager for jax_numpy_rank_promotion config option.
A context manager to control the transfer guard level for all transfers.
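Each configuration context manager scopes a `jax_*` option to a `with` block. The table above lists them without their names; this sketch assumes the `jax_debug_nans` and `jax_numpy_rank_promotion` rows correspond to `jax.debug_nans` and `jax.numpy_rank_promotion` in a recent JAX release:

```python
import jax
import jax.numpy as jnp

# jax.debug_nans(True): raise FloatingPointError as soon as a NaN is produced.
caught_nan = False
with jax.debug_nans(True):
    try:
        jnp.log(-1.0)
    except FloatingPointError:
        caught_nan = True

# jax.numpy_rank_promotion("raise"): forbid implicit rank promotion,
# e.g. adding a (3,) array to a (2, 3) array.
caught_rank = False
with jax.numpy_rank_promotion("raise"):
    try:
        jnp.ones(3) + jnp.ones((2, 3))
    except ValueError:
        caught_rank = True

# Outside the context managers, the global defaults apply again.
default_sum = jnp.ones(3) + jnp.ones((2, 3))
print(caught_nan, caught_rank, default_sum.shape)
```

Scoping options this way keeps expensive checks such as NaN detection confined to the code under investigation instead of enabling them process-wide.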
Just-in-time compilation (jit)#
Sets up
Context manager that disables
Context manager to ensure evaluation at trace/compile time (or error).
Creates a function that produces its XLA computation given example args.
Creates a function that produces its jaxpr given example args.
Compute the shape/dtype of
A container for the shape, dtype, and other static attributes of an array.
Transfers
Transfer array(s) to each specified device and form Array(s).
Transfer array shards to specified devices and form Array(s).
Transfer
Returns the platform name of the default XLA backend.
Adds a user-specified name to a function when staging out JAX computations.
A context manager that adds a user-specified name to the JAX name stack.
Tries to call a
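A minimal sketch of the compilation, inspection and transfer entries above, assuming they correspond to `jax.jit`, `jax.disable_jit`, `jax.make_jaxpr`, `jax.eval_shape`, `jax.ShapeDtypeStruct`, `jax.device_put`, `jax.device_get` and `jax.default_backend` (names not shown in the table), with JAX's default float32 precision. The `affine` function is illustrative:

```python
import jax
import jax.numpy as jnp
import numpy as np

def affine(w, x, b):
    return jnp.dot(w, x) + b

# Trace once per input shape/dtype signature, compile with XLA, then reuse.
affine_jit = jax.jit(affine)

w = jnp.ones((4, 3))
x = jnp.arange(3.0)
b = jnp.zeros(4)

# Dispatch is asynchronous; block_until_ready() waits for the result.
y = affine_jit(w, x, b).block_until_ready()

# The same call run op-by-op, with compilation switched off.
with jax.disable_jit():
    y_eager = affine_jit(w, x, b)

# Inspect the traced program (a jaxpr) without executing it.
jaxpr = jax.make_jaxpr(affine)(w, x, b)

# Shape/dtype inference from abstract inputs: no FLOPs are performed.
spec = jax.ShapeDtypeStruct
out_spec = jax.eval_shape(
    affine, spec((4, 3), jnp.float32), spec((3,), jnp.float32), spec((4,), jnp.float32)
)

# Explicit host-to-device and device-to-host transfers.
x_dev = jax.device_put(np.arange(3, dtype=np.float32))
y_host = jax.device_get(y)  # a NumPy array on the host

print(jax.default_backend())  # e.g. "cpu", "gpu", or "tpu"
print(jaxpr)
```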
Automatic differentiation#
Creates a function that evaluates the gradient of
Create a function that evaluates both
Jacobian of
Jacobian of
Hessian of
Computes a (forward-mode) Jacobian-vector product of
Produces a linear approximation to
Transpose a function that is promised to be linear.
Compute a (reverse-mode) vector-Jacobian product of
Set up a JAX-transformable function for a custom JVP rule definition.
Set up a JAX-transformable function for a custom VJP rule definition.
Convenience function for defining custom VJP rules (aka custom gradients).
Closure conversion utility, for use with higher-order custom derivatives.
Make
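The differentiation entries above compose freely. This sketch assumes they correspond to `jax.grad`, `jax.value_and_grad`, `jax.hessian`, `jax.jacfwd`, `jax.jacrev`, `jax.jvp`, `jax.vjp`, `jax.custom_jvp` and `jax.checkpoint`; the functions `f`, `h` and `log1pexp` are illustrative. It checks that forward and reverse mode agree and that a custom JVP rule can keep a gradient finite where the naive derivative overflows:

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0])

g = jax.grad(f)(x)                  # d/dx sum(x^3) = 3 x^2
val, g2 = jax.value_and_grad(f)(x)  # value and gradient together
H = jax.hessian(f)(x)               # diag(6 x)

def h(x):
    return jnp.sin(x) * x

J_fwd = jax.jacfwd(h)(x)            # built column-by-column, forward mode
J_rev = jax.jacrev(h)(x)            # built row-by-row, reverse mode

# Forward mode pushes a tangent through h; reverse mode pulls a cotangent back.
v = jnp.ones_like(x)
_, jvp_out = jax.jvp(h, (x,), (v,))  # J @ v
out, vjp_fn = jax.vjp(h, x)
(vjp_out,) = vjp_fn(v)               # v @ J

# Custom JVP: d/dx log(1 + e^x) = sigmoid(x), finite even when e^x overflows.
@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    return log1pexp(x), jax.nn.sigmoid(x) * x_dot

stable_grad = jax.grad(log1pexp)(100.0)

# jax.checkpoint recomputes intermediates on the backward pass instead of storing them.
g_ckpt = jax.grad(jax.checkpoint(f))(x)
```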
jax.Array (jax.Array)#
Array base class for JAX
Returns a
Returns a
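A `jax.Array` can be assembled from per-device pieces. This sketch assumes a single-process run, a recent JAX release, and that the two "Returns a" rows above are `jax.make_array_from_callback` and `jax.make_array_from_single_device_arrays`; the mesh axis name `"x"` and the data are illustrative:

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = jax.device_count()
global_data = np.arange(4 * n, dtype=np.float32)

# A 1-D mesh over all devices; split the array's only axis across it.
mesh = Mesh(np.array(jax.devices()), ("x",))
sharding = NamedSharding(mesh, P("x"))

def data_callback(index):
    # index is a tuple of slices selecting the shard one device should hold.
    return global_data[index]

arr = jax.make_array_from_callback(global_data.shape, sharding, data_callback)

# Rebuild the same global array from explicitly placed single-device pieces.
pieces = [jax.device_put(global_data[s.index], s.device) for s in arr.addressable_shards]
arr2 = jax.make_array_from_single_device_arrays(global_data.shape, sharding, pieces)

print(type(arr), arr.sharding)
```

On a machine with one device the mesh has a single entry and each array holds one shard, but the same code spreads the data across every available device.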
Vectorization (vmap)#
Vectorizing map.
Define a vectorized function with broadcasting.
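A function written for single examples can be batched without rewriting it. This sketch assumes the two rows above are `jax.vmap` and `jax.numpy.vectorize`; `dot` and the inputs are illustrative:

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    # Written for single vectors: (n,), (n,) -> ()
    return jnp.dot(a, b)

A = jnp.arange(6.0).reshape(3, 2)  # rows [0, 1], [2, 3], [4, 5]
B = jnp.ones((3, 2))

# Map over the leading axis of both arguments.
batched = jax.vmap(dot)(A, B)

# in_axes=(0, None): map over the rows of A, reuse one vector for b.
shared = jax.vmap(dot, in_axes=(0, None))(A, jnp.array([1.0, -1.0]))

# jnp.vectorize: NumPy-style broadcasting driven by a gufunc signature.
vec_dot = jnp.vectorize(jnp.dot, signature="(n),(n)->()")
broadcast = vec_dot(A, jnp.array([1.0, -1.0]))
```

`vmap` needs the batch axes spelled out through `in_axes`, whereas `vectorize` infers them by broadcasting against the core dimensions named in `signature`.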
Parallelization (pmap)#
Parallel map with support for collective operations.
Returns a list of all devices for a given backend.
Like
Returns the integer process index of this process.
Returns the total number of devices.
Returns the number of devices addressable by this process.
Returns the number of JAX processes associated with the backend.
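The parallel map and the device/process queries above can be combined in a few lines. This is a minimal sketch, assuming the rows correspond to `jax.pmap`, `jax.devices`, `jax.local_devices`, `jax.process_index`, `jax.device_count`, `jax.local_device_count` and `jax.process_count`; the axis name `"devices"` and `normalize` are illustrative. It runs on any number of local devices, including one:

```python
import jax
import jax.numpy as jnp
import numpy as np

n = jax.local_device_count()
xs = jnp.arange(n * 3.0).reshape(n, 3)  # one row per local device

def normalize(row):
    # psum sums over every device participating in the named axis.
    total = jax.lax.psum(jnp.sum(row), axis_name="devices")
    return row / total

out = jax.pmap(normalize, axis_name="devices")(xs)

print(jax.process_index(), jax.process_count())
print(jax.device_count(), jax.local_device_count())
print(jax.devices(), jax.local_devices())
```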
Callbacks#
Applies a functionally pure Python callable.
Calls an impure Python callback.
Calls a stageable Python callback.
Prints values and works in staged out JAX functions.
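Callbacks let a jitted function call back into the host. This sketch assumes three of the rows above are `jax.pure_callback`, `jax.debug.callback` and `jax.debug.print` in a recent JAX release (the impure-callback row is not exercised here); `host_sinc` and `record` are illustrative:

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sinc(x):
    # Runs on the host with NumPy; must return the declared shape and dtype.
    return np.sinc(np.asarray(x)).astype(np.float32)

seen = []

def record(v):
    # Side effect only: debug callbacks return nothing.
    seen.append(np.asarray(v).copy())

@jax.jit
def f(x):
    out_spec = jax.ShapeDtypeStruct(x.shape, jnp.float32)
    y = jax.pure_callback(host_sinc, out_spec, x)
    jax.debug.callback(record, y)
    jax.debug.print("sum of y = {s}", s=y.sum())
    return y

y = f(jnp.linspace(-1.0, 1.0, 5))
jax.effects_barrier()  # wait for pending callbacks to finish
```

`pure_callback` declares its output shape and dtype up front because the traced program must know them before the host function ever runs.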
Miscellaneous#
A descriptor of an available device.
Returns a string containing local environment & JAX installation information.
Return all live arrays in the backend for platform.
Clear all compilation and staging caches.
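A short sketch of the miscellaneous utilities above, assuming they correspond to `jax.Device`, `jax.print_environment_info`, `jax.live_arrays` and `jax.clear_caches`; the `square` function is illustrative:

```python
import jax
import jax.numpy as jnp

dev = jax.devices()[0]  # a jax.Device
info = jax.print_environment_info(return_string=True)

@jax.jit
def square(x):
    return x * x

x = jnp.ones((128, 128))
y = square(x)

live = jax.live_arrays()  # arrays currently alive on the default backend
jax.clear_caches()        # drop compiled and staged functions
y_again = square(x)       # recompiles on this call

print(dev.platform, dev.id)
print(info)
```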