Public API: jax package

Subpackages

Configuration

config

check_tracer_leaks

Context manager for jax_check_tracer_leaks config option.

checking_leaks

Context manager for jax_check_tracer_leaks config option.

debug_nans

Context manager for jax_debug_nans config option.

debug_infs

Context manager for jax_debug_infs config option.

default_device

Context manager for jax_default_device config option.

default_matmul_precision

Context manager for jax_default_matmul_precision config option.

default_prng_impl

Context manager for jax_default_prng_impl config option.

enable_checks

Context manager for jax_enable_checks config option.

enable_custom_prng

Context manager for jax_enable_custom_prng config option (transient).

enable_custom_vjp_by_custom_transpose

Context manager for jax_enable_custom_vjp_by_custom_transpose config option (transient).

log_compiles

Context manager for jax_log_compiles config option.

numpy_rank_promotion

Context manager for jax_numpy_rank_promotion config option.

transfer_guard(new_val)

A context manager to control the transfer guard level for all transfers.
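Each of the context managers above scopes one config option to a `with` block and restores the previous value on exit. The sketch below exercises two of them, assuming a JAX version that exports `jax.debug_nans` and `jax.numpy_rank_promotion` at the top level as listed here; the helper function names are hypothetical and only for illustration.

```python
import jax
import jax.numpy as jnp

def log_raises_under_debug_nans(x):
    """Hypothetical helper: True if jnp.log(x) raises while jax_debug_nans is on."""
    with jax.debug_nans(True):
        try:
            jnp.log(x)  # a NaN output is reported as FloatingPointError
        except FloatingPointError:
            return True
    return False

def rank_promotion_raises():
    """Hypothetical helper: True if implicit rank promotion is rejected in 'raise' mode."""
    with jax.numpy_rank_promotion("raise"):
        try:
            jnp.ones(3) + jnp.ones((2, 3))  # rank 1 + rank 2 needs implicit promotion
        except ValueError:
            return True
    return False
```

Outside the `with` blocks the defaults apply again, so `jnp.log(-1.0)` simply returns `nan`.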

Just-in-time compilation (jit)

jit(fun[, in_shardings, out_shardings, ...])

Sets up fun for just-in-time compilation with XLA.

disable_jit([disable])

Context manager that disables jit() behavior under its dynamic context.

ensure_compile_time_eval()

Context manager to ensure evaluation at trace/compile time (or error).

xla_computation(fun[, static_argnums, ...])

Creates a function that produces its XLA computation given example args.

make_jaxpr(fun[, static_argnums, axis_env, ...])

Creates a function that produces its jaxpr given example args.

eval_shape(fun, *args, **kwargs)

Compute the shape/dtype of fun without any FLOPs.

ShapeDtypeStruct(shape, dtype[, ...])

A container for the shape, dtype, and other static attributes of an array.

device_put(x[, device, src])

Transfers x to device.

device_put_replicated(x, devices)

Transfer array(s) to each specified device and form Array(s).

device_put_sharded(shards, devices)

Transfer array shards to specified devices and form Array(s).

device_get(x)

Transfer x to host.

default_backend()

Returns the platform name of the default XLA backend.

named_call(fun, *[, name])

Adds a user specified name to a function when staging out JAX computations.

named_scope(name)

A context manager that adds a user specified name to the JAX name stack.

block_until_ready(x)

Tries to call a block_until_ready method on pytree leaves.
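A minimal sketch tying several of the functions above together: `jit` to compile, `device_put`/`device_get` to move data, `block_until_ready` to wait on asynchronous dispatch, and `eval_shape`/`make_jaxpr` to inspect a function without running it. The function name `scaled_sum_impl` is illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np

def scaled_sum_impl(x, y):
    # Plain Python/NumPy-style code; jit traces it once per input shape/dtype.
    return jnp.sum(2.0 * x + y)

scaled_sum = jax.jit(scaled_sum_impl)

x = jax.device_put(jnp.arange(4.0))  # place on the default device
y = jnp.ones(4)

result = jax.block_until_ready(scaled_sum(x, y))  # wait for async dispatch
host_value = jax.device_get(result)               # copy back to host as NumPy

# eval_shape: propagate shapes and dtypes only, no FLOPs are spent.
spec = jax.ShapeDtypeStruct((4,), jnp.float32)
out_spec = jax.eval_shape(scaled_sum, spec, spec)

# make_jaxpr: inspect the staged-out program for the given example args.
jaxpr = jax.make_jaxpr(scaled_sum_impl)(x, y)
print(jaxpr)
```

Here `2 * [0, 1, 2, 3] + 1` sums to `16.0`, and `out_spec` describes a float32 scalar.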

Automatic differentiation

grad(fun[, argnums, has_aux, holomorphic, ...])

Creates a function that evaluates the gradient of fun.

value_and_grad(fun[, argnums, has_aux, ...])

Creates a function that evaluates both fun and the gradient of fun.

jacfwd(fun[, argnums, has_aux, holomorphic])

Jacobian of fun evaluated column-by-column using forward-mode AD.

jacrev(fun[, argnums, has_aux, holomorphic, ...])

Jacobian of fun evaluated row-by-row using reverse-mode AD.

hessian(fun[, argnums, has_aux, holomorphic])

Hessian of fun as a dense array.

jvp(fun, primals, tangents[, has_aux])

Computes a (forward-mode) Jacobian-vector product of fun.

linearize(fun, *primals[, has_aux])

Produces a linear approximation to fun using jvp() and partial eval.

linear_transpose(fun, *primals[, reduce_axes])

Transpose a function that is promised to be linear.

vjp(fun, *primals[, has_aux, reduce_axes])

Compute a (reverse-mode) vector-Jacobian product of fun.

custom_jvp(fun[, nondiff_argnums])

Set up a JAX-transformable function for a custom JVP rule definition.

custom_vjp(fun[, nondiff_argnums])

Set up a JAX-transformable function for a custom VJP rule definition.

custom_gradient(fun)

Convenience function for defining custom VJP rules (aka custom gradients).

closure_convert(fun, *example_args)

Closure conversion utility, for use with higher-order custom derivatives.

checkpoint(fun, *[, prevent_cse, policy, ...])

Make fun recompute internal linearization points when differentiated.
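A short sketch of the main differentiation entry points on the cubic `f(x) = sum(x**3)`, whose gradient is `3 * x**2` and whose Hessian is `diag(6 * x)`. The `softplus` function with a hand-written JVP is an illustrative example of `custom_jvp`, not a JAX built-in.

```python
import jax
import jax.numpy as jnp

def f(x):
    # f(x) = sum(x**3); grad f = 3 * x**2; Hessian = diag(6 * x)
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0])

g = jax.grad(f)(x)                    # [3., 12.]
value, g2 = jax.value_and_grad(f)(x)  # 9. and [3., 12.]
H = jax.hessian(f)(x)                 # [[6., 0.], [0., 12.]]

# Forward mode: Jacobian-vector product J @ v with v = [1, 1].
_, jv = jax.jvp(f, (x,), (jnp.ones(2),))  # 3 + 12 = 15

# Reverse mode: vector-Jacobian product u @ J with u = 1 (equals the gradient).
y, f_vjp = jax.vjp(f, x)
(uj,) = f_vjp(jnp.ones_like(y))

# custom_jvp: a softplus with an explicit derivative rule.
@jax.custom_jvp
def softplus(x):
    return jnp.logaddexp(x, 0.0)

@softplus.defjvp
def softplus_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # d/dx softplus(x) = sigmoid(x); the tangent output is linear in t.
    return softplus(x), jax.nn.sigmoid(x) * t
```

Because the custom rule is linear in the tangent, `jax.grad(softplus)` works too and gives `0.5` at `x = 0`.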

jax.Array (jax.Array)

Array()

Array base class for JAX.

make_array_from_callback(shape, sharding, ...)

Returns a jax.Array via data fetched from data_callback.

make_array_from_single_device_arrays(shape, ...)

Returns a jax.Array from a sequence of jax.Arrays each on a single device.
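A sketch of building a `jax.Array` both ways, assuming a single device and `jax.sharding.SingleDeviceSharding`; on a multi-device setup you would typically pass a `NamedSharding` over a mesh instead, and the callback would then be called once per shard. The names `global_data` and `data_callback` are illustrative.

```python
import jax
import numpy as np

global_data = np.arange(12, dtype=np.float32).reshape(3, 4)

# One device keeps the sketch portable; the whole array is a single shard.
device = jax.devices()[0]
sharding = jax.sharding.SingleDeviceSharding(device)

def data_callback(index):
    # `index` is a tuple of slices selecting the shard for one device.
    return global_data[index]

arr = jax.make_array_from_callback(global_data.shape, sharding, data_callback)

# Same result from already-placed per-device arrays (here, just one).
shards = [jax.device_put(global_data, device)]
arr2 = jax.make_array_from_single_device_arrays(global_data.shape, sharding, shards)
```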

Vectorization (vmap)

vmap(fun[, in_axes, out_axes, axis_name, ...])

Vectorizing map.

numpy.vectorize(pyfunc, *[, excluded, signature])

Define a vectorized function with broadcasting.
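A small sketch contrasting the two: `vmap` names the batch axis explicitly with `in_axes`, while `jax.numpy.vectorize` infers batching from a gufunc-style `signature`. The function `dot` is written for single vectors only.

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    # Written for one pair of vectors of shape (n,).
    return jnp.dot(a, b)

xs = jnp.arange(6.0).reshape(3, 2)  # batch of 3 vectors: [0,1], [2,3], [4,5]
w = jnp.array([1.0, 10.0])

# Map over axis 0 of xs; in_axes=None broadcasts w unchanged.
batched = jax.vmap(dot, in_axes=(0, None))(xs, w)  # [10., 32., 54.]

# Core dims (n) are consumed by dot; remaining leading dims are broadcast.
vec_dot = jnp.vectorize(dot, signature="(n),(n)->()")
batched2 = vec_dot(xs, w)
```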

Parallelization (pmap)

pmap(fun[, axis_name, in_axes, out_axes, ...])

Parallel map with support for collective operations.

devices([backend])

Returns a list of all devices for a given backend.

local_devices([process_index, backend, host_id])

Like jax.devices(), but only returns devices local to a given process.

process_index([backend])

Returns the integer process index of this process.

device_count([backend])

Returns the total number of devices.

local_device_count([backend])

Returns the number of devices addressable by this process.

process_count([backend])

Returns the number of JAX processes associated with the backend.
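A sketch of `pmap` with a collective, sized from `local_device_count()` so it runs on any machine (on a plain CPU install this is usually one device). The leading axis of the input must equal the number of devices used; `axis_name="i"` is an arbitrary label that `psum` refers to.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()                  # devices addressable by this process
xs = jnp.arange(float(n * 4)).reshape(n, 4)   # one row per device

def per_device(x):
    # Sum this device's row, then add the partial sums across all devices.
    return jax.lax.psum(jnp.sum(x), axis_name="i")

total = jax.pmap(per_device, axis_name="i")(xs)  # every entry holds the global sum

print(jax.process_index(), jax.process_count(), jax.device_count())
```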

Callbacks

pure_callback(callback, result_shape_dtypes, ...)

Applies a functionally pure Python callable.

experimental.io_callback(callback, ...[, ...])

Calls an impure Python callback.

debug.callback(callback, *args[, ordered])

Calls a stageable Python callback.

debug.print(fmt, *args[, ordered])

Prints values and works in staged out JAX functions.
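A sketch of calling host NumPy code from inside `jit` with `pure_callback`, plus `debug.print` for run-time printing. `pure_callback` needs the output shape and dtype up front, given here as a `ShapeDtypeStruct`. Depending on the JAX version the callback may receive NumPy arrays or `jax.Array`s, so the illustrative `host_sin` converts with `np.asarray` first.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
    # Runs on the host with plain NumPy; must be free of side effects.
    x = np.asarray(x)
    return np.sin(x).astype(x.dtype)

@jax.jit
def f(x):
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    y = jax.pure_callback(host_sin, out_spec, x)
    # Prints concrete values when the compiled function runs, not at trace time.
    jax.debug.print("x = {x}, y = {y}", x=x, y=y)
    return y

out = f(jnp.array([0.0, 1.0]))
```

Since JAX may elide or duplicate a pure callback, side effects belong in `experimental.io_callback` or `debug.callback` instead.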

Miscellaneous

Device

A descriptor of an available device.

print_environment_info([return_string])

Returns a string containing local environment & JAX installation information.

live_arrays([platform])

Return all live arrays in the backend for platform.

clear_caches()

Clear all compilation and staging caches.
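A brief sketch of the miscellaneous utilities, assuming a JAX version where `jax.live_arrays` is still exported at the top level as listed here.

```python
import jax
import jax.numpy as jnp

# return_string=True returns the report instead of printing it.
info = jax.print_environment_info(return_string=True)

dev = jax.devices()[0]   # a jax.Device descriptor
print(dev.platform, dev.id)

x = jnp.ones((2, 2))
alive = jax.live_arrays()  # arrays currently held by the default backend

jax.clear_caches()         # subsequent jit calls retrace and recompile
```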