Public API: jax package
Subpackages
jax.numpy module
jax.scipy module
jax.lax module
jax.random module
jax.sharding module
jax.debug module
jax.dlpack module
jax.distributed module
jax.dtypes module
jax.flatten_util module
jax.image module
jax.nn module
jax.ops module
jax.profiler module
jax.stages module
jax.tree module
jax.tree_util module
jax.typing module
jax.export module
jax.extend module
jax.example_libraries module
jax.experimental module
Configuration
Context manager for jax_check_tracer_leaks config option.
Context manager for jax_check_tracer_leaks config option.
Context manager for jax_debug_nans config option.
Context manager for jax_debug_infs config option.
Context manager for jax_default_device config option.
Context manager for jax_default_matmul_precision config option.
Context manager for jax_default_prng_impl config option.
Context manager for jax_enable_checks config option.
Context manager for jax_enable_custom_prng config option (transient).
Context manager for jax_enable_custom_vjp_by_custom_transpose config option (transient).
Context manager for jax_log_compiles config option.
Context manager for jax_numpy_rank_promotion config option.
A context manager to control the transfer guard level for all transfers.
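Each entry above is a context manager that overrides one config option only inside its `with` block. The following is a minimal sketch. It assumes a recent JAX release that exports `jax.numpy_rank_promotion` at the top level, and that the option is left at its default value `"allow"`. The shapes are arbitrary:

```python
import jax
import jax.numpy as jnp

# Default setting ("allow"): the rank-1 operand is implicitly promoted to rank 2.
allowed = jnp.ones(3) + jnp.ones((2, 3))
print(allowed.shape)  # (2, 3)

raised = False
with jax.numpy_rank_promotion("raise"):
    try:
        # Different shapes from above, so no cached computation is reused.
        jnp.ones(4) + jnp.ones((5, 4))
    except ValueError:
        # Inside the block, implicit rank promotion is reported as an error.
        raised = True

print(raised)
```

When the `with` block exits, the previous value of the option is restored.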
Just-in-time compilation (jit)
Sets up
Context manager that disables
Context manager to ensure evaluation at trace/compile time (or error).
Creates a function that produces its jaxpr given example args.
Compute the shape/dtype of
A container for the shape, dtype, and other static attributes of an array.
Transfers
Transfer array(s) to each specified device and form Array(s).
Transfer array shards to specified devices and form Array(s).
Transfer
Returns the platform name of the default XLA backend.
Adds a user-specified name to a function when staging out JAX computations.
A context manager that adds a user-specified name to the JAX name stack.
Tries to call a
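The sketch below exercises several of the entries above on `affine`, an illustrative function that is not part of JAX. It assumes 64-bit mode is disabled, which is JAX's default, so arrays are `float32`:

```python
import jax
import jax.numpy as jnp
import numpy as np

def affine(w, x, b):
    # Illustrative function to stage out: w @ x + b.
    return w @ x + b

fast_affine = jax.jit(affine)  # traced and compiled with XLA on first call

w = jnp.ones((4, 3))
x = jnp.arange(3.0)
b = jnp.zeros(4)

# Dispatch is asynchronous; block_until_ready waits for the result.
y = jax.block_until_ready(fast_affine(w, x, b))
print(y)  # [3. 3. 3. 3.]

# Print the staged-out program (jaxpr) without executing it.
print(jax.make_jaxpr(affine)(w, x, b))

# Shape/dtype propagation from abstract inputs only; nothing is computed.
out = jax.eval_shape(
    affine,
    jax.ShapeDtypeStruct((4, 3), jnp.float32),
    jax.ShapeDtypeStruct((3,), jnp.float32),
    jax.ShapeDtypeStruct((4,), jnp.float32),
)
print(out.shape, out.dtype)  # (4,) float32

# Host -> device and device -> host transfers.
x_dev = jax.device_put(np.arange(3, dtype=np.float32))
x_host = jax.device_get(x_dev)  # a NumPy array again
print(type(x_host).__name__, jax.default_backend())
```

`eval_shape` is handy for checking output shapes of large computations before allocating any device memory.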
Automatic differentiation
Creates a function that evaluates the gradient of
Create a function that evaluates both
Alias of
Jacobian of
Jacobian of
Hessian of
Computes a (forward-mode) Jacobian-vector product of
Produces a linear approximation to
Transpose a function that is promised to be linear.
Compute a (reverse-mode) vector-Jacobian product of
Convenience function for defining custom VJP rules (aka custom gradients).
Closure conversion utility, for use with higher-order custom derivatives.
Make
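A minimal sketch of the main differentiation transforms, applied to an illustrative scalar function `f(x) = sum(x**3)`. The function is chosen only because its derivatives are easy to check by hand:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Illustrative scalar-valued function: sum of cubes.
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0])

g = jax.grad(f)(x)                  # 3 * x**2 -> [3., 12.]
val, g2 = jax.value_and_grad(f)(x)  # f(x) = 9. together with the same gradient
H = jax.hessian(f)(x)               # diag(6 * x) -> [[6., 0.], [0., 12.]]

# Jacobians of a vector-valued function, in reverse and forward mode.
J = jax.jacrev(lambda v: v ** 2)(x)      # diag(2 * v) -> [[2., 0.], [0., 4.]]
J_fwd = jax.jacfwd(lambda v: v ** 2)(x)  # same matrix

# Forward mode: push the tangent [1, 0] through f.
y, y_dot = jax.jvp(f, (x,), (jnp.array([1.0, 0.0]),))  # y_dot = 3.

# Reverse mode: pull a unit cotangent back through f.
y2, f_vjp = jax.vjp(f, x)
(x_bar,) = f_vjp(jnp.ones_like(y2))  # equals the gradient g
```

For a scalar output, one reverse-mode pass (`vjp` with a unit cotangent) recovers the full gradient, which is why `grad` is built on reverse mode.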
custom_jvp
Set up a JAX-transformable function for a custom JVP rule definition.
Define a custom JVP rule for the function represented by this instance.
Convenience wrapper for defining JVPs for each argument separately.
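The following sketch defines a custom JVP rule. It is adapted from the `log1pexp` example in the JAX documentation, and the function names are illustrative. For large inputs the naive formula overflows in `float32`, so its automatic derivative becomes NaN, while the hand-written rule stays finite:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    # Naive formula: exp(x) overflows to inf for large x in float32.
    return jnp.log(1.0 + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = log1pexp(x)
    # d/dx log(1 + e^x) = 1 - 1 / (1 + e^x), which stays finite when e^x = inf.
    ans_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return ans, ans_dot

# Differentiating the naive formula directly gives 0 * inf = nan.
naive = jax.grad(lambda x: jnp.log(1.0 + jnp.exp(x)))(100.0)

print(naive, jax.grad(log1pexp)(100.0))  # nan 1.0
```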
custom_vjp
Set up a JAX-transformable function for a custom VJP rule definition.
Define a custom VJP rule for the function represented by this instance.
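The following sketch defines a custom VJP rule, adapted from the gradient-clipping example in the JAX documentation. The names `clip_gradient`, `clip_gradient_fwd`, `clip_gradient_bwd` and `loss` are illustrative. The forward pass is the identity, and the backward pass clips the incoming cotangent:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_gradient(lo, hi, x):
    return x  # identity on the forward pass

def clip_gradient_fwd(lo, hi, x):
    # Return the primal output plus residuals saved for the backward pass.
    return x, (lo, hi)

def clip_gradient_bwd(res, g):
    lo, hi = res
    # None means "no gradient" for lo and hi; the cotangent for x is clipped.
    return (None, None, jnp.clip(g, lo, hi))

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

def loss(x):
    return 10.0 * clip_gradient(-0.5, 0.5, x)

print(jax.grad(loss)(1.0))  # 0.5: the cotangent 10.0 is clipped to 0.5
```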
jax.Array (jax.Array)
Array base class for JAX
Returns a
Returns a
Creates a distributed tensor using the data available in the process.
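One way to build a `jax.Array` from host data is `jax.make_array_from_callback`, which asks a callback for the data of each shard. The following is a minimal sketch. It assumes a single local device, described by `jax.sharding.SingleDeviceSharding`, so the callback is invoked for one shard covering the whole array:

```python
import jax
import numpy as np
from jax.sharding import SingleDeviceSharding

host_data = np.arange(12, dtype=np.float32).reshape(3, 4)
sharding = SingleDeviceSharding(jax.devices()[0])

# The callback receives the index (a tuple of slices) of each shard and
# returns the matching block of host data; here there is one shard.
arr = jax.make_array_from_callback(
    host_data.shape, sharding, lambda index: host_data[index]
)

print(arr.shape, arr.sharding)
print(isinstance(arr, jax.Array), arr.is_fully_addressable)  # True True
```

With a multi-device sharding, the same callback would be called once per addressable shard, each time with a different index.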
Array properties and methods
List of addressable shards.
Test whether all array elements along a given axis evaluate to True.
Test whether any array elements along a given axis evaluate to True.
Return the index of the maximum value.
Return the index of the minimum value.
Return the indices that partially sort the array.
Return the indices that sort the array.
Copy the array and cast to a specified dtype.
Helper property for index update functionality.
Construct an array choosing from elements of multiple arrays.
Return an array whose values are limited to a specified range.
Return selected slices of this array along a given axis.
Return the complex conjugate of the array.
Return the complex conjugate of the array.
Return a copy of the array.
Copies an
Return the cumulative product of the array.
Return the cumulative sum of the array.
Array API-compatible device attribute.
Return the specified diagonal from the array.
Compute the dot product of two arrays.
The data type (
Use
Flatten array into a 1-dimensional shape.
List of global shards.
Return the imaginary part of the array.
Is this Array fully addressable?
Is this Array fully replicated?
Copy an element of an array to a standard Python scalar and return it.
Length of one array element in bytes.
Return the maximum of array elements along a given axis.
Return the mean of array elements along a given axis.
Return the minimum of array elements along a given axis.
Total bytes consumed by the elements of the array.
The number of dimensions in the array.
Return indices of nonzero elements of an array.
Return the product of the array elements over a given axis.
Return the peak-to-peak range along a given axis.
Flatten array into a 1-dimensional shape.
Return the real part of the array.
Construct an array from repeated elements.
Returns an array containing the same data with a new shape.
Round array elements to a given number of decimals.
Perform a binary search within a sorted array.
The shape of the array.
The sharding for the array.
The total number of elements in the array.
Return a sorted copy of an array.
Remove one or more length-1 axes from the array.
Compute the standard deviation along a given axis.
Sum of the elements of the array over a given axis.
Swap two axes of an array.
Take elements from an array.
Return a copy of the array on the specified device.
Return the sum along the diagonal.
Returns a copy of the array with axes transposed.
Compute the variance along a given axis.
Return a bitwise copy of the array, viewed as a new dtype.
Compute the all-axis array transpose.
Compute the (batched) matrix transpose.
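A short sketch of some of the properties and methods above, including the `.at` index-update helper that replaces in-place assignment on immutable JAX arrays. It assumes a JAX version recent enough to provide `Array.mT`, and the default 32-bit mode, so `x` is `int32`:

```python
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# JAX arrays are immutable: the .at property returns updated copies.
y = x.at[0, 1].set(10)
z = x.at[:, 0].add(100)
print(x[0, 1], y[0, 1])  # 1 10

print(x.sum(axis=0))                   # [3 5 7]
print(x.T.shape, x.mT.shape)           # (3, 2) (3, 2)
print(x.astype(jnp.float32).dtype)     # float32
print(x.max().item(), x.ndim, x.size)  # 5 2 6
```

For a 2-D array `T` and `mT` coincide. They differ for higher ranks, where `mT` swaps only the last two axes.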
Vectorization (vmap)
Vectorizing map.
Define a vectorized function with broadcasting.
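The following sketch batches a single-vector function in two ways. The helper `dot` is illustrative. `jax.vmap` maps over an explicit axis of each argument, and `in_axes=None` leaves an argument unbatched. `jnp.vectorize` expresses the same batching through a NumPy-style gufunc signature with broadcasting:

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    # Written for single vectors (rank-1 inputs).
    return jnp.dot(a, b)

A = jnp.arange(6.0).reshape(3, 2)  # a batch of 3 vectors
v = jnp.array([1.0, -1.0])

# Map over axis 0 of A while passing v through unchanged.
batched = jax.vmap(dot, in_axes=(0, None))(A, v)
print(batched)  # [-1. -1. -1.]

# Same result via a gufunc-style signature; v is broadcast against A.
vec_dot = jnp.vectorize(dot, signature="(n),(n)->()")
print(vec_dot(A, v))
```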
Parallelization (pmap)
Parallel map with support for collective operations.
Returns a list of all devices for a given backend.
Like
Returns the integer process index of this process.
Returns the total number of devices.
Returns the number of devices addressable by this process.
Returns the number of JAX processes associated with the backend.
Returns the list of all JAX process indices associated with the backend.
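The following is a minimal `pmap` sketch with a collective operation. It sizes the mapped axis with `jax.local_device_count()`, so it also runs on a machine with a single CPU device. The helper `device_sum` and the axis name `"i"` are illustrative choices:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()  # devices addressable by this process

# The leading axis is split across devices, so its size must be <= n.
xs = jnp.arange(n * 3.0).reshape(n, 3)

def device_sum(x):
    # Each device sums its own row, then psum adds the partial sums
    # across the named mapped axis "i".
    return jax.lax.psum(x.sum(), axis_name="i")

total = jax.pmap(device_sum, axis_name="i")(xs)

print(jax.process_index(), jax.process_count(), jax.device_count())
print(total)  # n copies of xs.sum()
```

Because `psum` reduces across the mapped axis, every device ends up holding the same total.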
Callbacks
Calls a pure Python callback.
Calls an impure Python callback.
Calls a stageable Python callback.
Prints values and works in staged-out JAX functions.
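The following sketch calls host-side NumPy code from inside a jitted function. It uses `jax.pure_callback`, which needs the result's shape and dtype declared up front so that the call can be staged out. It also uses `jax.debug.print`, which prints runtime values under `jit`. The helper `host_sin` is illustrative, and the sketch assumes the default 32-bit mode:

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
    # Runs on the host as ordinary NumPy code, outside of XLA.
    return np.sin(x).astype(np.float32)

@jax.jit
def f(x):
    # Declare the callback's result shape/dtype so it can be staged out.
    y = jax.pure_callback(host_sin, jax.ShapeDtypeStruct(x.shape, jnp.float32), x)
    jax.debug.print("y = {y}", y=y)  # prints concrete values at runtime
    return y * 2

out = f(jnp.array([0.0, 1.0]))
```

Because the callback is declared pure, JAX may skip the call if its result is unused. Side-effecting host code needs one of the impure callback variants instead.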
Miscellaneous
A descriptor of an available device.
Returns a string containing local environment & JAX installation information.
Return all live arrays in the backend for platform.
Clear all compilation and staging caches.
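A short sketch of the miscellaneous utilities, assuming a recent JAX release. `jax.devices()` returns `jax.Device` descriptors. `jax.live_arrays()` lists the arrays still alive on the default backend. `jax.clear_caches()` drops compiled artifacts. `jax.print_environment_info(return_string=True)` returns its report instead of printing it:

```python
import jax
import jax.numpy as jnp

dev = jax.devices()[0]  # a jax.Device describing one available device
print(dev.platform, dev.id)

x = jnp.ones((1024,))
# live_arrays() returns the arrays currently alive on the default backend,
# which includes x at this point.
print(len(jax.live_arrays()))

jax.clear_caches()  # drop compilation and staging caches

# Return the environment report as a string instead of printing it.
info = jax.print_environment_info(return_string=True)
print(info)
```

Clearing caches forces the next call of any jitted function to be traced and compiled again.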