Public API: jax package

Subpackages

Just-in-time compilation (jit)

jit(fun, *[, static_argnums, ...])

Sets up fun for just-in-time compilation with XLA.

disable_jit([disable])

Context manager that disables jit() behavior under its dynamic context.

ensure_compile_time_eval()

Context manager to ensure evaluation at trace/compile time (or error).

xla_computation(fun[, static_argnums, ...])

Creates a function that produces its XLA computation given example args.

make_jaxpr(fun[, static_argnums, axis_env, ...])

Creates a function that produces its jaxpr given example args.

eval_shape(fun, *args, **kwargs)

Computes the shape/dtype of fun without any FLOPs.

device_put(x[, device])

Transfers x to device.

device_put_replicated(x, devices)

Transfer array(s) to each specified device and form ShardedDeviceArray(s).

device_put_sharded(shards, devices)

Transfer array shards to specified devices and form ShardedDeviceArray(s).

device_get(x)

Transfers x to host.

default_backend()

Returns the platform name of the default XLA backend.

named_call(fun, *[, name])

Adds a user specified name to a function when staging out JAX computations.

named_scope(name)

A context manager that adds a user specified name to the JAX name stack.

block_until_ready(x)

Tries to call a block_until_ready method on pytree leaves.
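As a rough illustration of how several of the entries above fit together, the sketch below compiles a small function with jit(), inspects it with make_jaxpr() and eval_shape(), and then runs it eagerly under disable_jit(). The function scaled_sum and its inputs are hypothetical and exist only for this example.

```python
import jax
import jax.numpy as jnp


def scaled_sum(x, scale):
    # Hypothetical example function; not part of the JAX API.
    return jnp.sum(x * scale)


x = jnp.arange(4.0)  # [0., 1., 2., 3.]

# jit() returns a wrapped function that is traced and compiled on first call.
fast_scaled_sum = jax.jit(scaled_sum)
result = fast_scaled_sum(x, 2.0)

# Dispatch is asynchronous, so wait for the value before timing or reading it.
result = jax.block_until_ready(result)

# make_jaxpr() returns a function that produces the traced jaxpr for example args.
jaxpr = jax.make_jaxpr(scaled_sum)(x, 2.0)

# eval_shape() reports the output shape and dtype without running the computation.
out_spec = jax.eval_shape(scaled_sum, x, 2.0)

# Under disable_jit(), the jitted function runs op by op, which helps debugging.
with jax.disable_jit():
    eager_result = fast_scaled_sum(x, 2.0)

# device_put() moves data to a device; device_get() brings it back to the host.
x_on_device = jax.device_put(x)
x_on_host = jax.device_get(x_on_device)
```

Printing jaxpr shows the primitive operations that jit() would hand to XLA, which is often the quickest way to see what tracing captured.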

Automatic differentiation

grad(fun[, argnums, has_aux, holomorphic, ...])

Creates a function that evaluates the gradient of fun.

value_and_grad(fun[, argnums, has_aux, ...])

Creates a function that evaluates both fun and the gradient of fun.

jacfwd(fun[, argnums, has_aux, holomorphic])

Jacobian of fun evaluated column-by-column using forward-mode AD.

jacrev(fun[, argnums, has_aux, holomorphic, ...])

Jacobian of fun evaluated row-by-row using reverse-mode AD.

hessian(fun[, argnums, has_aux, holomorphic])

Hessian of fun as a dense array.

jvp(fun, primals, tangents[, has_aux])

Computes a (forward-mode) Jacobian-vector product of fun.

linearize(fun, *primals)

Produces a linear approximation to fun using jvp() and partial eval.

linear_transpose(fun, *primals[, reduce_axes])

Transpose a function that is promised to be linear.

vjp(fun, *primals[, has_aux, reduce_axes])

Computes a (reverse-mode) vector-Jacobian product of fun.

custom_jvp(fun[, nondiff_argnums])

Set up a JAX-transformable function for a custom JVP rule definition.

custom_vjp(fun[, nondiff_argnums])

Set up a JAX-transformable function for a custom VJP rule definition.

closure_convert(fun, *example_args)

Closure conversion utility, for use with higher-order custom derivatives.

checkpoint(fun, *[, prevent_cse, policy, ...])

Make fun recompute internal linearization points when differentiated.
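The following sketch shows how the gradient, Jacobian, and product transformations listed above relate to each other. The functions loss and f are hypothetical examples chosen so the results are easy to check by hand. Because f acts elementwise, its Jacobian is diagonal, so the JVP and VJP against a vector of ones agree.

```python
import jax
import jax.numpy as jnp


def loss(w):
    # Hypothetical scalar-valued function: sum of squares, gradient is 2 * w.
    return jnp.sum(w ** 2)


def f(v):
    # Hypothetical elementwise vector-valued function.
    return jnp.sin(v) * v


w = jnp.array([1.0, 2.0, 3.0])
ones = jnp.ones_like(w)

# grad() and value_and_grad() apply to scalar-valued functions.
g = jax.grad(loss)(w)
value, g_again = jax.value_and_grad(loss)(w)

# hessian() returns a dense array; here it is 2 * identity.
H = jax.hessian(loss)(w)

# jvp() pushes a tangent vector forward through f.
primal_out, tangent_out = jax.jvp(f, (w,), (ones,))

# vjp() returns the output and a function that pulls a cotangent back through f.
primal_out_rev, f_vjp = jax.vjp(f, w)
(cotangent_out,) = f_vjp(ones)

# jacfwd() and jacrev() build the same full Jacobian by different modes.
J_fwd = jax.jacfwd(f)(w)
J_rev = jax.jacrev(f)(w)
```

As a rule of thumb, jacfwd() tends to suit functions with few inputs and many outputs, and jacrev() the reverse.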

jax.Array (jax.Array)

make_array_from_callback(shape, sharding, ...)

Returns a jax.Array via data fetched from data_callback.

make_array_from_single_device_arrays(shape, ...)

Returns a jax.Array from a sequence of jax.Arrays, each on a single device.
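A minimal sketch of make_array_from_callback() follows, assuming the simplest case of a single-device sharding from jax.sharding. The names host_data and data_callback are hypothetical. The callback receives an index (a tuple of slices) for each addressable shard and returns the matching piece of host data.

```python
import numpy as np
import jax
from jax.sharding import SingleDeviceSharding

# Hypothetical host-side data to be turned into a jax.Array.
host_data = np.arange(8, dtype=np.float32).reshape(4, 2)

# Assumption: place the whole array on the first available device.
sharding = SingleDeviceSharding(jax.devices()[0])


def data_callback(index):
    # index selects the region of the global array that one device holds.
    return host_data[index]


arr = jax.make_array_from_callback(host_data.shape, sharding, data_callback)
```

With a multi-device sharding the same callback would be invoked once per shard, each time with a different index.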

Vectorization (vmap)

vmap(fun[, in_axes, out_axes, axis_name, ...])

Vectorizing map: creates a function that maps fun over argument axes.

numpy.vectorize(pyfunc, *[, excluded, signature])

Define a vectorized function with broadcasting.
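To make the two entries above concrete, the sketch below batches a hypothetical dot function over the leading axis of its first argument, once with vmap() and in_axes, and once with jax.numpy.vectorize and a signature string. Both are meant to give the same result here.

```python
import jax
import jax.numpy as jnp


def dot(a, b):
    # Hypothetical per-example function written for 1-D inputs.
    return jnp.dot(a, b)


batch_a = jnp.arange(6.0).reshape(3, 2)  # rows: [0, 1], [2, 3], [4, 5]
b = jnp.array([1.0, 10.0])

# Map over axis 0 of batch_a; None means b is shared across the batch.
vmap_out = jax.vmap(dot, in_axes=(0, None))(batch_a, b)

# The signature marks the core dimensions; any leading axes are broadcast.
vec_dot = jnp.vectorize(dot, signature="(n),(n)->()")
vec_out = vec_dot(batch_a, b)
```

vmap() asks for explicit axis positions, while the vectorize form follows NumPy-style broadcasting over whatever leading dimensions are present.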

Parallelization (pmap)

pmap(fun[, axis_name, in_axes, out_axes, ...])

Parallel map with support for collective operations.

devices([backend])

Returns a list of all devices for a given backend.

local_devices([process_index, backend, host_id])

Like jax.devices(), but only returns devices local to a given process.

process_index([backend])

Returns the integer process index of this process.

device_count([backend])

Returns the total number of devices.

local_device_count([backend])

Returns the number of devices addressable by this process.

process_count([backend])

Returns the number of JAX processes associated with the backend.
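The sketch below shows one way pmap() and the device queries above might be combined. It sizes the input by local_device_count() so that it also runs on a machine with a single device, and uses jax.lax.psum as the collective. The function normalize and the axis name "i" are hypothetical choices for this example.

```python
import jax
import jax.numpy as jnp

# Query the devices visible to this process.
n_local = jax.local_device_count()
n_total = jax.device_count()
all_devices = jax.devices()


def normalize(x):
    # psum adds x across every device participating in the named axis.
    return x / jax.lax.psum(x, axis_name="i")


# The leading axis must match the number of devices being mapped over.
xs = jnp.arange(1.0, n_local + 1.0)

# Each device receives one element of xs and runs normalize on it.
out = jax.pmap(normalize, axis_name="i")(xs)
```

After the call, the elements of out sum to 1 because each device divided its value by the cross-device total.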

Callbacks

pure_callback(callback, result_shape_dtypes, ...)

Applies a functionally pure Python callable.

debug.callback(callback, *args[, ordered])

Calls a stageable Python callback.
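As an illustration of the two callbacks above, the sketch below calls a NumPy function from inside a jitted function with pure_callback(), and reports an intermediate value with debug.callback(). The names host_log1p, report, and apply_on_host are hypothetical. pure_callback() needs the result shape and dtype up front, supplied here as a jax.ShapeDtypeStruct.

```python
import numpy as np
import jax
import jax.numpy as jnp


def host_log1p(x):
    # Runs on the host in plain NumPy; must be free of side effects.
    return np.log1p(np.asarray(x))


def report(total):
    # Side-effecting host function, intended for debugging only.
    print("sum of outputs:", total)


@jax.jit
def apply_on_host(x):
    # Declare the callback's output shape and dtype so tracing can proceed.
    result_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    y = jax.pure_callback(host_log1p, result_spec, x)
    # debug.callback stages a call to report without affecting the result.
    jax.debug.callback(report, y.sum())
    return y


x = jnp.array([0.0, 1.0], dtype=jnp.float32)
out = jax.block_until_ready(apply_on_host(x))
```

Because pure_callback() is assumed to be pure, JAX may elide or repeat the call, so anything with side effects belongs in debug.callback() instead.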

Miscellaneous

print_environment_info([return_string])

Returns a string containing local environment & JAX installation information.