Public API: jax package


Just-in-time compilation (jit)

jit(fun, *[, static_argnums, ...])

Sets up fun for just-in-time compilation with XLA.
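A minimal sketch of how jit() is commonly applied. The function name normalize and the input values are illustrative assumptions, not part of the API.

```python
import jax
import jax.numpy as jnp

# Hypothetical example function: standardize a vector.
def normalize(x):
    return (x - x.mean()) / x.std()

# Wrap the function. Tracing and compilation happen on the first call for a
# given input shape and dtype; later calls with the same signature reuse
# the compiled executable.
normalize_jit = jax.jit(normalize)

x = jnp.arange(6.0)
out = normalize_jit(x)
```

The jitted function should return the same values as the plain Python version, up to floating-point rounding.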


Context manager that disables jit() behavior under its dynamic context.


Context manager to ensure evaluation at trace/compile time (or error).

xla_computation(fun[, static_argnums, ...])

Creates a function that produces its XLA computation given example args.

make_jaxpr(fun[, static_argnums, axis_env, ...])

Creates a function that produces its jaxpr given example args.
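As a small illustration, make_jaxpr() can be used to inspect what JAX traces for a function. The function f below is a hypothetical example.

```python
import jax
import jax.numpy as jnp

# Hypothetical function to inspect.
def f(x):
    return jnp.sin(x) * 2.0

# make_jaxpr returns a new function; calling it with example arguments
# traces f and returns the resulting jaxpr instead of a numeric value.
closed_jaxpr = jax.make_jaxpr(f)(1.0)
print(closed_jaxpr)
```

The printed representation lists the primitive operations (here a sine and a multiplication) that f was traced to.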

eval_shape(fun, *args, **kwargs)

Compute the shape/dtype of fun without any FLOPs.
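A short sketch of eval_shape() with abstract inputs. The function matmul and the shapes chosen are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical function whose output shape we want to know.
def matmul(a, b):
    return jnp.dot(a, b)

# ShapeDtypeStruct describes an array by shape and dtype only, so no real
# data has to be allocated for the inputs.
a = jax.ShapeDtypeStruct((3, 4), jnp.float32)
b = jax.ShapeDtypeStruct((4, 5), jnp.float32)

out = jax.eval_shape(matmul, a, b)
print(out.shape, out.dtype)
```

The result is itself an abstract description (shape and dtype), not an array of values.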

device_put(x[, device])

Transfers x to device.
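A minimal sketch of moving data to a device and back. It assumes the first entry of jax.devices() is an acceptable target; the variable names are illustrative.

```python
import jax
import numpy as np

# Host-side NumPy data.
x_host = np.arange(4, dtype=np.float32)

# Pick a device explicitly (assumption: the first available device).
device = jax.devices()[0]

# Copy the data onto the device; the result is a jax.Array.
x_dev = jax.device_put(x_host, device)

# jax.device_get copies the values back to host memory as a NumPy array.
x_back = jax.device_get(x_dev)
```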

device_put_replicated(x, devices)

Transfer array(s) to each specified device and form ShardedDeviceArray(s).

device_put_sharded(shards, devices)

Transfer array shards to specified devices and form ShardedDeviceArray(s).


Transfer x to host.


Returns the platform name of the default XLA backend.

named_call(fun, *[, name])

Adds a user specified name to a function when staging out JAX computations.


A context manager that adds a user specified name to the JAX name stack.


Tries to call a block_until_ready method on pytree leaves.

Automatic differentiation

grad(fun[, argnums, has_aux, holomorphic, ...])

Creates a function that evaluates the gradient of fun.

value_and_grad(fun[, argnums, has_aux, ...])

Creates a function that evaluates both fun and the gradient of fun.
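To illustrate grad() and value_and_grad() together, here is a sketch using a hypothetical mean-squared-error loss. The names loss, w, x and y and the data are assumptions for the example.

```python
import jax
import jax.numpy as jnp

# Hypothetical loss: mean squared error of a linear model.
def loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

w = jnp.array([1.0, -1.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([0.0, 1.0])

# By default both transformations differentiate with respect to the
# first positional argument (argnums=0), here w.
g_only = jax.grad(loss)(w, x, y)
value, g = jax.value_and_grad(loss)(w, x, y)
```

value_and_grad() avoids a second forward evaluation when both the loss value and its gradient are needed, as is typical in a training loop.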

jacfwd(fun[, argnums, has_aux, holomorphic])

Jacobian of fun evaluated column-by-column using forward-mode AD.

jacrev(fun[, argnums, has_aux, holomorphic, ...])

Jacobian of fun evaluated row-by-row using reverse-mode AD.
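A brief sketch comparing jacfwd() and jacrev() on a hypothetical function from R^2 to R^3. Both should give the same matrix; they differ in how it is assembled.

```python
import jax
import jax.numpy as jnp

# Hypothetical vector-valued function with 2 inputs and 3 outputs.
def f(x):
    return jnp.array([x[0] ** 2, x[0] * x[1], jnp.sin(x[1])])

x = jnp.array([1.0, 2.0])

J_fwd = jax.jacfwd(f)(x)  # built column by column with forward mode
J_rev = jax.jacrev(f)(x)  # built row by row with reverse mode
```

As a rule of thumb, forward mode tends to be cheaper when there are fewer inputs than outputs, and reverse mode when there are fewer outputs than inputs.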

hessian(fun[, argnums, has_aux, holomorphic])

Hessian of fun as a dense array.
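A small worked example of hessian(), using a hypothetical scalar function whose second derivatives are easy to check by hand.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar function: f(x) = x0^2 * x1 + x1^3.
def f(x):
    return x[0] ** 2 * x[1] + x[1] ** 3

x = jnp.array([1.0, 2.0])
H = jax.hessian(f)(x)

# Analytic second derivatives at (1, 2):
#   d2f/dx0^2   = 2 * x1 = 4
#   d2f/dx0 dx1 = 2 * x0 = 2
#   d2f/dx1^2   = 6 * x1 = 12
expected = jnp.array([[4.0, 2.0], [2.0, 12.0]])
```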

jvp(fun, primals, tangents[, has_aux])

Computes a (forward-mode) Jacobian-vector product of fun.
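A minimal sketch of a single Jacobian-vector product with jvp(). The function f and the chosen point and direction are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical elementwise function.
def f(x):
    return jnp.sin(x) * x

x = jnp.array([0.5, 1.0, 2.0])  # point at which to differentiate
v = jnp.array([1.0, 0.0, -1.0])  # tangent (direction) vector

# primals and tangents are passed as tuples with one entry per argument.
y, tangent_out = jax.jvp(f, (x,), (v,))

# For this elementwise f the derivative is cos(x) * x + sin(x).
expected = (jnp.cos(x) * x + jnp.sin(x)) * v
```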

linearize(fun, *primals)

Produces a linear approximation to fun using jvp() and partial eval.
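As a sketch, linearize() returns the primal output together with a function that applies the linear approximation at that point. The point 2.0 and the tangent 1.0 are arbitrary choices for the example.

```python
import jax
import jax.numpy as jnp

x = jnp.array(2.0)

# y is sin(x); f_lin maps an input tangent to the output tangent,
# using the linearization of sin around x.
y, f_lin = jax.linearize(jnp.sin, x)

t = f_lin(jnp.array(1.0))  # derivative of sin at x, i.e. cos(x)
```

This can be useful when the same linearization point is reused for many tangent vectors, since the primal computation is not repeated on each call to f_lin.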

linear_transpose(fun, *primals[, reduce_axes])

Transpose a function that is promised to be linear.


Compute a (reverse-mode) vector-Jacobian product of fun.
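A minimal sketch of a vector-Jacobian product using jax.vjp. The matrix A, the function f and the cotangent u are hypothetical values chosen for the example.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# Hypothetical function from R^2 to R^3.
def f(x):
    return jnp.tanh(A @ x)

x = jnp.array([0.5, -0.25])

# vjp returns the primal output and a function that pulls a cotangent
# on the output back to cotangents on the inputs.
y, vjp_fn = jax.vjp(f, x)

u = jnp.ones_like(y)   # cotangent on the output
(x_bar,) = vjp_fn(u)   # one cotangent per input argument
```

The result x_bar equals the transposed Jacobian of f at x applied to u, without materializing the Jacobian.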

custom_jvp(fun[, nondiff_argnums])

Set up a JAX-transformable function for a custom JVP rule definition.
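A sketch of a custom JVP rule, here used to obtain a numerically stable derivative. The function log1pexp is an illustrative choice; any function with a known derivative could be used in its place.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

# The rule receives the primal inputs and their tangents and must return
# the primal output together with the output tangent.
@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,) = primals
    (t,) = tangents
    ans = log1pexp(x)
    # Derivative written as 1 - 1 / (1 + exp(x)), which stays finite
    # even when exp(x) overflows to infinity.
    return ans, (1.0 - 1.0 / (1.0 + jnp.exp(x))) * t

g_small = jax.grad(log1pexp)(0.0)
g_large = jax.grad(log1pexp)(100.0)
```

Without the custom rule, differentiating at a large input can yield nan, because the automatically derived expression divides one overflowed quantity by another.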

custom_vjp(fun[, nondiff_argnums])

Set up a JAX-transformable function for a custom VJP rule definition.
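A sketch of a custom VJP rule that leaves the forward pass unchanged but clips the gradient flowing backward. The name clip_grad and the bounds of -1 and 1 are assumptions made for this example.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_grad(x):
    # Forward pass: identity.
    return x

def clip_grad_fwd(x):
    # Returns the primal output and the residuals saved for the backward
    # pass. Nothing needs to be saved here.
    return x, None

def clip_grad_bwd(residuals, g):
    # Must return a tuple with one cotangent per input of clip_grad.
    return (jnp.clip(g, -1.0, 1.0),)

clip_grad.defvjp(clip_grad_fwd, clip_grad_bwd)

# The incoming cotangent is 10, which the rule clips to 1.
g = jax.grad(lambda x: 10.0 * clip_grad(x))(2.0)
```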

closure_convert(fun, *example_args)

Closure conversion utility, for use with higher-order custom derivatives.

checkpoint(fun, *[, prevent_cse, policy, ...])

Make fun recompute internal linearization points when differentiated.
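A minimal sketch of checkpoint() (also known as rematerialization). The function block is a hypothetical stand-in for a larger sub-computation.

```python
import jax
import jax.numpy as jnp

# Hypothetical sub-computation whose intermediates we prefer to
# recompute during the backward pass rather than store.
def block(x):
    return jnp.sin(jnp.sin(x)).sum()

block_remat = jax.checkpoint(block)

x = jnp.linspace(0.0, 1.0, 5)
g_plain = jax.grad(block)(x)
g_remat = jax.grad(block_remat)(x)
```

The gradients should agree; the difference is a trade of extra computation for lower peak memory during differentiation.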

jax.Array (jax.Array)

make_array_from_callback(shape, sharding, ...)

Returns a jax.Array via data fetched from data_callback.

make_array_from_single_device_arrays(shape, ...)

Returns a jax.Array from a sequence of jax.Arrays on a single device.

Vectorization (vmap)

vmap(fun[, in_axes, out_axes, axis_name, ...])

Vectorizing map.
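A short sketch of vmap() that batches a per-example function over the leading axis of one argument. The function dot and the data are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example function: dot product of two vectors.
def dot(a, b):
    return jnp.dot(a, b)

# Map over axis 0 of the first argument; in_axes=None for the second
# argument means it is shared (not mapped) across the batch.
batched_dot = jax.vmap(dot, in_axes=(0, None))

A = jnp.arange(6.0).reshape(3, 2)  # batch of 3 vectors of length 2
v = jnp.array([1.0, 2.0])

out = batched_dot(A, v)
```

The result has one entry per row of A and matches the matrix-vector product A @ v.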

numpy.vectorize(pyfunc, *[, excluded, signature])

Define a vectorized function with broadcasting.

Parallelization (pmap)

pmap(fun[, axis_name, in_axes, out_axes, ...])

Parallel map with support for collective operations.
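A sketch of pmap() with a collective operation. It sizes the mapped axis to the number of local devices so that it also runs on a single-device machine; the axis name "i" and the function normalize are assumptions for the example.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# One row per local device; the leading axis is the mapped axis.
x = jnp.arange(n * 2.0).reshape(n, 2) + 1.0

# Each device divides its row by the sum over all devices, computed
# with the psum collective along the named axis.
def normalize(v):
    return v / jax.lax.psum(v, axis_name="i")

out = jax.pmap(normalize, axis_name="i")(x)
```

After normalization, summing the output over the device axis should give ones in every column.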


Returns a list of all devices for a given backend.

local_devices([process_index, backend, host_id])

Like jax.devices(), but only returns devices local to a given process.


Returns the integer process index of this process.


Returns the total number of devices.


Returns the number of devices addressable by this process.


Returns the number of JAX processes associated with the backend.
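The device and process queries in this section can be combined as in the sketch below. The variable names are illustrative, and the printed values depend on the machine.

```python
import jax

all_devices = jax.devices()          # devices across all processes
local_devices = jax.local_devices()  # devices attached to this process

print("total devices:", jax.device_count())
print("local devices:", jax.local_device_count())
print("process index:", jax.process_index())
print("process count:", jax.process_count())
```

In a single-process setup the global and local device lists are expected to coincide.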


pure_callback(callback, result_shape_dtypes, ...)

Applies a functionally pure Python callable.
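A minimal sketch of pure_callback() that calls a NumPy function from inside a jitted computation. The helper name host_log1p is hypothetical; the callback is assumed to be free of side effects.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical host-side function implemented with NumPy.
def host_log1p(x):
    return np.log1p(x)

@jax.jit
def f(x):
    # The shape and dtype of the result must be declared up front,
    # because JAX cannot trace into the Python callback.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_log1p, result_shape, x)

x = jnp.array([0.0, 1.0, 2.0])
out = f(x)
```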

debug.callback(callback, *args[, ordered])

Calls a stageable Python callback.
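A sketch of debug.callback() used to record a value from inside a jitted function. The list seen and the function record are hypothetical; jax.effects_barrier() is used here on the assumption that it is available in the installed JAX version, to wait for the callback to finish.

```python
import jax

seen = []

# Hypothetical host-side callback with a side effect.
def record(x):
    seen.append(float(x))

@jax.jit
def f(x):
    # The callback runs on the host with the runtime value of x.
    jax.debug.callback(record, x)
    return x * 2

y = f(3.0)

# Wait for outstanding side effects before inspecting `seen`.
jax.effects_barrier()
```

Unlike pure_callback(), this is intended for side effects such as logging, and its return value is not used in the computation.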



Returns a string containing local environment & JAX installation information.