Public API: jax package
Subpackages#
- jax.numpy package
- jax.scipy package
- JAX configuration
- jax.dlpack module
- jax.distributed module
- jax.example_libraries package
- jax.experimental package
- jax.flatten_util package
- jax.image package
- jax.lax package
- jax.nn package
- jax.ops package
- jax.profiler module
- jax.random package
- jax.tree_util package
Just-in-time compilation (jit)#
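| jax.jit | Sets up fun for just-in-time compilation with XLA. |
| jax.disable_jit | Context manager that disables jit() behavior under its dynamic context. |
| jax.ensure_compile_time_eval | Context manager to ensure evaluation at trace/compile time (or error). |
| jax.xla_computation | Creates a function that produces its XLA computation given example args. |
| jax.make_jaxpr | Creates a function that produces its jaxpr given example args. |
| jax.eval_shape | Compute the shape/dtype of fun without any FLOPs. |
| jax.device_put | Transfers x to device. |
| jax.device_put_replicated | Transfer array(s) to each specified device and form ShardedDeviceArray(s). |
| jax.device_put_sharded | Transfer array shards to specified devices and form ShardedDeviceArray(s). |
| jax.device_get | Transfer x to host. |
| jax.default_backend | Returns the platform name of the default XLA backend. |
| jax.named_call | Adds a user-specified name to a function when staging out JAX computations. |
| jax.block_until_ready | Tries to call a block_until_ready method on pytree leaves. |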
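As a rough illustration of how a few of the entries above fit together, the following minimal sketch compiles a function with `jax.jit`, inspects its jaxpr with `jax.make_jaxpr`, and queries its output shape and dtype with `jax.eval_shape`. The function `scaled_sum` and the input values are hypothetical and chosen only for this example.

```python
import jax
import jax.numpy as jnp


def scaled_sum(x):
    # Hypothetical example function (not part of the JAX API).
    return jnp.sum(2.0 * x)


# jax.jit returns a compiled version of the function; the first call
# traces and compiles, later calls with the same shapes reuse the result.
fast_scaled_sum = jax.jit(scaled_sum)

x = jnp.arange(4.0)          # [0., 1., 2., 3.]
result = fast_scaled_sum(x)  # 2 * (0 + 1 + 2 + 3)

# jax.make_jaxpr returns a function that, given example args, produces
# the jaxpr (JAX's intermediate representation) instead of running it.
jaxpr = jax.make_jaxpr(scaled_sum)(x)

# jax.eval_shape reports the output shape/dtype without doing any FLOPs.
out_spec = jax.eval_shape(scaled_sum, x)

print(result)
print(jaxpr)
print(out_spec.shape, out_spec.dtype)
```

Printing the jaxpr is often a convenient way to check what a function looks like after tracing, before any compilation cost is paid.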
Automatic differentiation#
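| jax.grad | Creates a function that evaluates the gradient of fun. |
| jax.value_and_grad | Create a function that evaluates both fun and the gradient of fun. |
| jax.jacfwd | Jacobian of fun evaluated column-by-column using forward-mode AD. |
| jax.jacrev | Jacobian of fun evaluated row-by-row using reverse-mode AD. |
| jax.hessian | Hessian of fun as a dense array. |
| jax.jvp | Computes a (forward-mode) Jacobian-vector product of fun. |
| jax.linearize | Produces a linear approximation to fun using jvp() and partial eval. |
| jax.linear_transpose | Transpose a function that is promised to be linear. |
| jax.vjp | Compute a (reverse-mode) vector-Jacobian product of fun. |
| jax.custom_jvp | Set up a JAX-transformable function for a custom JVP rule definition. |
| jax.custom_vjp | Set up a JAX-transformable function for a custom VJP rule definition. |
| jax.closure_convert | Closure conversion utility, for use with higher-order custom derivatives. |
| jax.checkpoint | Make fun recompute internal linearization points when differentiated. |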
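The differentiation entries above can be compared on one small function. This is a minimal sketch, assuming a hypothetical function `cubic(x) = sum(x**3)`, whose gradient is `3 * x**2` and whose Hessian is `diag(6 * x)`; the input values are illustrative only.

```python
import jax
import jax.numpy as jnp


def cubic(x):
    # Hypothetical scalar-valued example function: sum of cubes.
    return jnp.sum(x ** 3)


x = jnp.array([1.0, 2.0])

# Reverse-mode gradient, and value plus gradient in a single call.
g = jax.grad(cubic)(x)
value, g_again = jax.value_and_grad(cubic)(x)

# Forward mode: push a tangent vector through the function.
tangent = jnp.array([1.0, 0.0])
primal_out, tangent_out = jax.jvp(cubic, (x,), (tangent,))

# Reverse mode: pull a cotangent back to the input.
out, vjp_fn = jax.vjp(cubic, x)
(cotangent_in,) = vjp_fn(jnp.ones_like(out))

# Dense Hessian of the scalar function.
H = jax.hessian(cubic)(x)

# Full Jacobians of a vector-valued function, built two ways.
J_fwd = jax.jacfwd(jnp.sin)(x)
J_rev = jax.jacrev(jnp.sin)(x)

print(g, value, tangent_out, cotangent_in)
print(H)
```

For a scalar-valued function, the cotangent returned by `vjp_fn` with a cotangent of one coincides with the gradient, and `jacfwd` and `jacrev` should agree up to floating-point error; which one is cheaper generally depends on whether the function has more inputs or more outputs.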
Vectorization (vmap)#
| jax.vmap | Vectorizing map. |
| jax.numpy.vectorize | Define a vectorized function with broadcasting. |
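To make the difference between the two entries concrete, here is a small sketch that batches a hypothetical per-example function `dot` first with `jax.vmap` and then with `jax.numpy.vectorize`. The array values are made up for illustration.

```python
import jax
import jax.numpy as jnp


def dot(a, b):
    # Hypothetical per-example function: dot product of two 1-D vectors.
    return jnp.dot(a, b)


A = jnp.arange(6.0).reshape(3, 2)   # rows: [0, 1], [2, 3], [4, 5]
B = jnp.ones((3, 2))
shared = jnp.array([1.0, -1.0])

# By default vmap maps over axis 0 of every argument.
rowwise = jax.vmap(dot)(A, B)

# in_axes=(0, None) maps over rows of A while reusing `shared` unchanged.
against_shared = jax.vmap(dot, in_axes=(0, None))(A, shared)

# jnp.vectorize reaches a similar result through a gufunc-style signature
# and NumPy broadcasting rules instead of explicit axis arguments.
vec_dot = jnp.vectorize(dot, signature="(n),(n)->()")
broadcasted = vec_dot(A, shared)

print(rowwise, against_shared, broadcasted)
```

`jax.vmap` tends to be the more explicit choice when the batch axis is known, while `jnp.vectorize` can be convenient when inputs should broadcast against each other the way NumPy ufuncs do.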
Parallelization (pmap)#
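| jax.pmap | Parallel map with support for collective operations. |
| jax.devices | Returns a list of all devices for a given backend. |
| jax.local_devices | Like jax.devices(), but only returns devices local to a given process. |
| jax.process_index | Returns the integer process index of this process. |
| jax.device_count | Returns the total number of devices. |
| jax.local_device_count | Returns the number of devices addressable by this process. |
| jax.process_count | Returns the number of JAX processes associated with the backend. |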
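The device queries and `jax.pmap` are typically used together. The sketch below assumes nothing about the hardware: it sizes the leading axis of the input to `jax.local_device_count()`, so it should also run on a single CPU device. The axis name `"i"` and the input data are arbitrary choices for this example.

```python
import jax
import jax.numpy as jnp

# Inspect the devices visible to this process and to the whole job.
print(jax.devices())
print(jax.local_devices())
print(jax.process_index(), jax.process_count())
print(jax.device_count(), jax.local_device_count())

n = jax.local_device_count()

# pmap maps over the leading axis, placing one slice on each local device,
# so the leading axis must not exceed the number of local devices.
xs = jnp.arange(n * 3.0).reshape(n, 3)

# Each device sums its own slice independently.
per_device_sum = jax.pmap(lambda x: jnp.sum(x))(xs)

# A collective: psum adds the per-device sums across the named axis,
# so every device ends up holding the same global total.
total = jax.pmap(
    lambda x: jax.lax.psum(jnp.sum(x), axis_name="i"),
    axis_name="i",
)(xs)

print(per_device_sum, total)
```

On a single-device machine `n` is 1, so the collective reduces over one element and `total` simply equals the one per-device sum.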