Public API: jax package

Subpackages

Configuration

config

check_tracer_leaks

Context manager for jax_check_tracer_leaks config option.

checking_leaks

Context manager for jax_check_tracer_leaks config option.

debug_nans

Context manager for jax_debug_nans config option.

debug_infs

Context manager for jax_debug_infs config option.

default_device

Context manager for jax_default_device config option.

default_matmul_precision

Context manager for jax_default_matmul_precision config option.

default_prng_impl

Context manager for jax_default_prng_impl config option.

enable_checks

Context manager for jax_enable_checks config option.

enable_custom_prng

Context manager for jax_enable_custom_prng config option (transient).

enable_custom_vjp_by_custom_transpose

Context manager for jax_enable_custom_vjp_by_custom_transpose config option (transient).

log_compiles

Context manager for jax_log_compiles config option.

numpy_rank_promotion

Context manager for jax_numpy_rank_promotion config option.

transfer_guard(new_val)

A context manager to control the transfer guard level for all transfers.
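Each of the options above can be enabled for a limited scope with a `with` block. The sketch below uses `jax.debug_nans`, assuming a recent JAX release in which the option object is callable as a context manager. The helper `log_raises_under_debug_nans` is hypothetical and exists only for illustration.

```python
import jax
import jax.numpy as jnp


def log_raises_under_debug_nans(x):
    """Return True if jnp.log(x) trips the jax_debug_nans check (illustrative helper)."""
    try:
        # The option is only active inside this block.
        with jax.debug_nans(True):
            jnp.log(x)  # log of a negative number produces NaN
    except FloatingPointError:
        return True
    return False


print(log_raises_under_debug_nans(-1.0))  # NaN is produced, so the check raises
print(log_raises_under_debug_nans(2.0))   # finite result, so nothing is raised
```

Outside the `with` block, `jnp.log(-1.0)` returns NaN silently, which is why scoping the check this way is convenient when hunting for the source of a NaN.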

Just-in-time compilation (jit)

jit(fun[, in_shardings, out_shardings, ...])

Sets up fun for just-in-time compilation with XLA.

disable_jit([disable])

Context manager that disables jit() behavior under its dynamic context.

ensure_compile_time_eval()

Context manager to ensure evaluation at trace/compile time (or error).

make_jaxpr([axis_env, return_shape, ...])

Creates a function that produces its jaxpr given example args.

eval_shape(fun, *args, **kwargs)

Compute the shape/dtype of fun without any FLOPs.

ShapeDtypeStruct(shape, dtype, *[, ...])

A container for the shape, dtype, and other static attributes of an array.

device_put(x[, device, src, donate, may_alias])

Transfers x to device.

device_put_replicated(x, devices)

Transfer array(s) to each specified device and form Array(s).

device_put_sharded(shards, devices)

Transfer array shards to specified devices and form Array(s).

device_get(x)

Transfer x to host.

default_backend()

Returns the platform name of the default XLA backend.

named_call(fun, *[, name])

Adds a user-specified name to a function when staging out JAX computations.

named_scope(name)

A context manager that adds a user-specified name to the JAX name stack.

block_until_ready(x)

Tries to call a block_until_ready method on pytree leaves.
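A short sketch tying several of these entries together: compiling with `jit`, waiting on the asynchronous result with `block_until_ready`, inspecting shapes with `eval_shape` and `ShapeDtypeStruct`, viewing the staged program with `make_jaxpr`, and moving data with `device_put`/`device_get`. The `selu` function and its constants are illustrative, not part of the API.

```python
import numpy as np
import jax
import jax.numpy as jnp


def selu(x, alpha=1.67, lam=1.05):
    # Scaled exponential linear unit written with jax.numpy ops (illustrative).
    return lam * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


selu_jit = jax.jit(selu)  # traced and compiled with XLA on the first call

x = jnp.arange(5.0)
y = selu_jit(x).block_until_ready()  # JAX dispatches asynchronously; wait here

# eval_shape traces abstractly: only shapes and dtypes, no FLOPs are spent.
out_struct = jax.eval_shape(selu, jax.ShapeDtypeStruct((1000, 3), jnp.float32))
print(out_struct.shape, out_struct.dtype)

# make_jaxpr shows the program JAX stages out for these example arguments.
jaxpr = jax.make_jaxpr(selu)(x)
print(jaxpr)

# device_put copies host data to the default device; device_get brings it back.
x_dev = jax.device_put(np.arange(3.0))
x_host = jax.device_get(x_dev)  # a NumPy array on the host
```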

Automatic differentiation

grad(fun[, argnums, has_aux, holomorphic, ...])

Creates a function that evaluates the gradient of fun.

value_and_grad(fun[, argnums, has_aux, ...])

Create a function that evaluates both fun and the gradient of fun.

jacobian(fun[, argnums, has_aux, ...])

Alias of jax.jacrev().

jacfwd(fun[, argnums, has_aux, holomorphic])

Jacobian of fun evaluated column-by-column using forward-mode AD.

jacrev(fun[, argnums, has_aux, holomorphic, ...])

Jacobian of fun evaluated row-by-row using reverse-mode AD.

hessian(fun[, argnums, has_aux, holomorphic])

Hessian of fun as a dense array.

jvp(fun, primals, tangents[, has_aux])

Computes a (forward-mode) Jacobian-vector product of fun.

linearize(fun, *primals[, has_aux])

Produces a linear approximation to fun using jvp() and partial eval.

linear_transpose(fun, *primals[, reduce_axes])

Transpose a function that is promised to be linear.

vjp(fun, *primals[, has_aux, reduce_axes])

Compute a (reverse-mode) vector-Jacobian product of fun.

custom_gradient(fun)

Convenience function for defining custom VJP rules (aka custom gradients).

closure_convert(fun, *example_args)

Closure conversion utility, for use with higher-order custom derivatives.

checkpoint(fun, *[, prevent_cse, policy, ...])

Make fun recompute internal linearization points when differentiated.
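The sketch below compares several of these transformations on one scalar function, f(x) = sum(sin(x) * x), whose gradient is cos(x) * x + sin(x). The function `f` is illustrative. It shows that `grad`, the forward-mode `jvp`, and the reverse-mode `vjp` agree where they should.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Scalar-valued function of a vector (illustrative).
    return jnp.sum(jnp.sin(x) * x)


x = jnp.array([0.0, 1.0, 2.0])

g = jax.grad(f)(x)                   # gradient, same shape as x
val, g2 = jax.value_and_grad(f)(x)   # value and gradient together

# Forward mode: Jacobian-vector product with tangent v.
v = jnp.ones_like(x)
primal_out, tangent_out = jax.jvp(f, (x,), (v,))

# Reverse mode: vjp returns the output and a function that pulls back cotangents.
out, f_vjp = jax.vjp(f, x)
(cotangent_in,) = f_vjp(jnp.ones_like(out))

# For a scalar function of a length-3 vector, the Hessian is a dense (3, 3) array.
H = jax.hessian(f)(x)
```

For a scalar output, pulling back a cotangent of 1.0 through `vjp` reproduces the gradient, while `jvp` with tangent `v` yields the directional derivative `g @ v`.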

custom_jvp

custom_jvp(fun[, nondiff_argnums])

Set up a JAX-transformable function for a custom JVP rule definition.

custom_jvp.defjvp(jvp[, symbolic_zeros])

Define a custom JVP rule for the function represented by this instance.

custom_jvp.defjvps(*jvps)

Convenience wrapper for defining JVPs for each argument separately.
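A common reason to use `custom_jvp` is numerical stability. In this sketch the naive formula log(1 + e^x) differentiates to NaN for large x in float32, so a stable derivative, 1 - 1 / (1 + e^x), is supplied through `defjvp`. The name `log1pexp` is illustrative.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def log1pexp(x):
    # Naive formula; differentiating it directly overflows for large x.
    return jnp.log(1.0 + jnp.exp(x))


@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = log1pexp(x)
    # d/dx log(1 + e^x) = 1 - 1 / (1 + e^x), which stays finite as e^x overflows.
    ans_dot = (1.0 - 1.0 / (1.0 + jnp.exp(x))) * x_dot
    return ans, ans_dot


print(jax.grad(log1pexp)(100.0))  # 1.0 rather than nan
```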

custom_vjp

custom_vjp(fun[, nondiff_argnums])

Set up a JAX-transformable function for a custom VJP rule definition.

custom_vjp.defvjp(fwd, bwd[, ...])

Define a custom VJP rule for the function represented by this instance.
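With `custom_vjp`, the forward pass saves residuals and the backward rule returns one cotangent per primal input. This sketch defines an identity function whose backward pass clips the incoming cotangent. Returning `None` marks a zero cotangent for the bound arguments. The names `clip_gradient` and `f` are illustrative.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def clip_gradient(lo, hi, x):
    return x  # identity on the forward pass


def clip_gradient_fwd(lo, hi, x):
    # Output plus residuals (the bounds) saved for the backward pass.
    return x, (lo, hi)


def clip_gradient_bwd(res, g):
    lo, hi = res
    # One cotangent per primal input; None means zero for lo and hi.
    return (None, None, jnp.clip(g, lo, hi))


clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)


def f(x):
    # The cotangent reaching clip_gradient is 3.0; the custom rule clips it to 1.0.
    return 3.0 * clip_gradient(-1.0, 1.0, x)


print(jax.grad(f)(2.0))  # 1.0, whereas the unclipped gradient would be 3.0
```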

jax.Array (jax.Array)

Array()

Array base class for JAX.

make_array_from_callback(shape, sharding, ...)

Returns a jax.Array via data fetched from data_callback.

make_array_from_single_device_arrays(shape, ...)

Returns a jax.Array from a sequence of jax.Arrays each on a single device.

make_array_from_process_local_data(sharding, ...)

Creates a distributed tensor using the data available in the process.
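A minimal sketch of `make_array_from_callback`: JAX asks the callback for each addressable shard by passing a tuple of slices into the global shape. To stay runnable on any machine, this builds a one-device mesh from the first entry of `jax.devices()`. On a multi-device host you would typically pass more devices. `global_data` and `data_callback` are illustrative names.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Global data that each shard is sliced from (illustrative).
global_data = np.arange(16, dtype=np.float32).reshape(4, 4)

# Assumption: a 1-D mesh over a single device so the example runs anywhere.
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("x",))
sharding = NamedSharding(mesh, P("x"))  # shard rows along mesh axis "x"


def data_callback(index):
    # `index` is a tuple of slices selecting this device's shard.
    return global_data[index]


arr = jax.make_array_from_callback(global_data.shape, sharding, data_callback)
print(arr.shape, len(arr.addressable_shards))
```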

Array properties and methods

Array.addressable_shards

List of addressable shards.

Array.all([axis, out, keepdims, where])

Test whether all array elements along a given axis evaluate to True.

Array.any([axis, out, keepdims, where])

Test whether any array elements along a given axis evaluate to True.

Array.argmax([axis, out, keepdims])

Return the index of the maximum value.

Array.argmin([axis, out, keepdims])

Return the index of the minimum value.

Array.argpartition(kth[, axis])

Return the indices that partially sort the array.

Array.argsort([axis, kind, order, stable, ...])

Return the indices that sort the array.

Array.astype(dtype[, copy, device])

Copy the array and cast to a specified dtype.

Array.at

Helper property for index update functionality.

Array.choose(choices[, out, mode])

Construct an array choosing from elements of multiple arrays.

Array.clip([min, max])

Return an array whose values are limited to a specified range.

Array.compress(condition[, axis, out, size, ...])

Return selected slices of this array along a given axis.

Array.conj()

Return the complex conjugate of the array.

Array.conjugate()

Return the complex conjugate of the array.

Array.copy()

Return a copy of the array.

Array.copy_to_host_async()

Copies an Array to the host asynchronously.

Array.cumprod([axis, dtype, out])

Return the cumulative product of the array.

Array.cumsum([axis, dtype, out])

Return the cumulative sum of the array.

Array.device

Array API-compatible device attribute.

Array.diagonal([offset, axis1, axis2])

Return the specified diagonal from the array.

Array.dot(b, *[, precision, ...])

Compute the dot product of two arrays.

Array.dtype

The data type (numpy.dtype) of the array.

Array.flat

Use flatten() instead.

Array.flatten([order])

Flatten array into a 1-dimensional shape.

Array.global_shards

List of global shards.

Array.imag

Return the imaginary part of the array.

Array.is_fully_addressable

Is this Array fully addressable?

Array.is_fully_replicated

Is this Array fully replicated?

Array.item(*args)

Copy an element of an array to a standard Python scalar and return it.

Array.itemsize

Length of one array element in bytes.

Array.max([axis, out, keepdims, initial, where])

Return the maximum of array elements along a given axis.

Array.mean([axis, dtype, out, keepdims, where])

Return the mean of array elements along a given axis.

Array.min([axis, out, keepdims, initial, where])

Return the minimum of array elements along a given axis.

Array.nbytes

Total bytes consumed by the elements of the array.

Array.ndim

The number of dimensions in the array.

Array.nonzero(*[, fill_value, size])

Return indices of nonzero elements of an array.

Array.prod([axis, dtype, out, keepdims, ...])

Return product of the array elements over a given axis.

Array.ptp([axis, out, keepdims])

Return the peak-to-peak range along a given axis.

Array.ravel([order])

Flatten array into a 1-dimensional shape.

Array.real

Return the real part of the array.

Array.repeat(repeats[, axis, ...])

Construct an array from repeated elements.

Array.reshape(*args[, order])

Returns an array containing the same data with a new shape.

Array.round([decimals, out])

Round array elements to a given decimal.

Array.searchsorted(v[, side, sorter, method])

Perform a binary search within a sorted array.

Array.shape

The shape of the array.

Array.sharding

The sharding for the array.

Array.size

The total number of elements in the array.

Array.sort([axis, kind, order, stable, ...])

Return a sorted copy of an array.

Array.squeeze([axis])

Remove one or more length-1 axes from the array.

Array.std([axis, dtype, out, ddof, ...])

Compute the standard deviation along a given axis.

Array.sum([axis, dtype, out, keepdims, ...])

Sum of the elements of the array over a given axis.

Array.swapaxes(axis1, axis2)

Swap two axes of an array.

Array.take(indices[, axis, out, mode, ...])

Take elements from an array.

Array.to_device(device, *[, stream])

Return a copy of the array on the specified device.

Array.trace([offset, axis1, axis2, dtype, out])

Return the sum along the diagonal.

Array.transpose(*args)

Returns a copy of the array with axes transposed.

Array.var([axis, dtype, out, ddof, ...])

Compute the variance along a given axis.

Array.view([dtype, type])

Return a bitwise copy of the array, viewed as a new dtype.

Array.T

Compute the all-axis array transpose.

Array.mT

Compute the (batched) matrix transpose.
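A few of these methods in use. Because `jax.Array` is immutable, the `Array.at` helper returns an updated copy instead of modifying the array in place. Most other methods mirror their NumPy counterparts.

```python
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

# .at[...] returns a new array; x itself is left unchanged.
y = x.at[0, 1].set(10.0)
z = x.at[:, 0].add(1.0)

print(x[0, 1], y[0, 1])          # original value vs. updated copy
print(x.T.shape, x.mT.shape)     # full transpose vs. matrix (last-two-axes) transpose
print(x.sum(axis=0))             # column sums
print(x.astype(jnp.int32).dtype)
```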

Vectorization (vmap)

vmap(fun[, in_axes, out_axes, axis_name, ...])

Vectorizing map.

numpy.vectorize(pyfunc, *[, excluded, signature])

Define a vectorized function with broadcasting.
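A sketch of batching a function written for single vectors. `vmap` maps over axis 0 of the first argument and broadcasts the second (`in_axes=None`). `jnp.vectorize` expresses the same computation with a gufunc-style signature. The function `dot` is illustrative.

```python
import jax
import jax.numpy as jnp


def dot(a, b):
    # Written for single vectors of shape (n,).
    return jnp.dot(a, b)


xs = jnp.ones((8, 3))  # a batch of 8 vectors
ys = jnp.arange(3.0)   # one vector shared by every batch element

batched_dot = jax.vmap(dot, in_axes=(0, None))
print(batched_dot(xs, ys).shape)  # (8,)

# Same computation via broadcasting over the non-core dimensions.
vec_dot = jnp.vectorize(dot, signature="(n),(n)->()")
```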

Parallelization (pmap)

pmap(fun[, axis_name, in_axes, out_axes, ...])

Parallel map with support for collective operations.

devices([backend])

Returns a list of all devices for a given backend.

local_devices([process_index, backend, host_id])

Like jax.devices(), but only returns devices local to a given process.

process_index([backend])

Returns the integer process index of this process.

device_count([backend])

Returns the total number of devices.

local_device_count([backend])

Returns the number of devices addressable by this process.

process_count([backend])

Returns the number of JAX processes associated with the backend.

process_indices([backend])

Returns the list of all JAX process indices associated with the backend.
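A small sketch combining the device queries with `pmap` and a collective. The leading axis of the input is sized to `jax.local_device_count()`, so it also runs on a single CPU device. Inside the mapped function, `jax.lax.psum` sums across all devices along the named axis `"i"`, which is an illustrative axis name.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()  # devices addressable by this process
print(jax.process_index(), jax.process_count(), jax.device_count())

# pmap maps over the leading axis, one slice per local device.
xs = jnp.arange(n * 4.0).reshape(n, 4)

# psum is a collective: it sums the per-device totals across axis "i".
normalize = jax.pmap(
    lambda x: x / jax.lax.psum(jnp.sum(x), axis_name="i"),
    axis_name="i",
)
out = normalize(xs)  # all entries together now sum to 1
```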

Callbacks

pure_callback(callback, result_shape_dtypes, ...)

Calls a pure Python callback.

experimental.io_callback(callback, ...[, ...])

Calls an impure Python callback.

debug.callback(callback, *args[, ordered])

Calls a stageable Python callback.

debug.print(fmt, *args[, ordered])

Prints values and works in staged-out JAX functions.
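A sketch of calling host-side NumPy code from inside `jit`. `pure_callback` needs the result's shape and dtype up front, given here as a `ShapeDtypeStruct`. `jax.debug.print` prints runtime values rather than tracers. The function `host_sin` is illustrative. Because `pure_callback` assumes the callback has no side effects, JAX may skip or repeat the call.

```python
import numpy as np
import jax
import jax.numpy as jnp


def host_sin(x):
    # Runs on the host with plain NumPy, outside of XLA (illustrative).
    return np.sin(x).astype(x.dtype)


@jax.jit
def f(x):
    # Tell JAX the shape/dtype the callback will return.
    y = jax.pure_callback(host_sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
    jax.debug.print("x = {x}, y = {y}", x=x, y=y)  # printed at run time
    return y * 2.0


out = f(jnp.array([0.0, 1.0]))
```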

Miscellaneous

Device

A descriptor of an available device.

print_environment_info([return_string])

Returns a string containing local environment & JAX installation information.

live_arrays([platform])

Return all live arrays in the backend for platform.

clear_caches()

Clear all compilation and staging caches.
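A brief sketch of the utilities listed in this section: collecting environment information as a string, inspecting the `Device` an array lives on, listing live arrays, and clearing caches. The exact text of the environment report varies by installation, so only its first line is printed.

```python
import jax
import jax.numpy as jnp

# return_string=True returns the report instead of printing it.
info = jax.print_environment_info(return_string=True)
print(info.splitlines()[0])

x = jnp.ones((1024,))
dev = next(iter(x.devices()))  # a jax.Device describing where x is stored
print(dev.platform, dev.id)

# x is still referenced, so at least one array is live on the backend.
print(len(jax.live_arrays()) >= 1)

jax.clear_caches()  # drop compilation and staging caches
```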