Public API: jax package

Subpackages

Configuration

config

check_tracer_leaks

Context manager for jax_check_tracer_leaks config option.

checking_leaks

Context manager for jax_check_tracer_leaks config option.

debug_nans

Context manager for jax_debug_nans config option.

debug_infs

Context manager for jax_debug_infs config option.

default_device

Context manager for jax_default_device config option.

default_matmul_precision

Context manager for jax_default_matmul_precision config option.

default_prng_impl

Context manager for jax_default_prng_impl config option.

enable_checks

Context manager for jax_enable_checks config option.

enable_custom_prng

Context manager for jax_enable_custom_prng config option (transient).

enable_custom_vjp_by_custom_transpose

Context manager for jax_enable_custom_vjp_by_custom_transpose config option (transient).

log_compiles

Context manager for jax_log_compiles config option.

numpy_rank_promotion

Context manager for jax_numpy_rank_promotion config option.

transfer_guard(new_val)

A context manager to control the transfer guard level for all transfers.
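Each configuration entry above can be used as a context manager that scopes the option to a with block. The following is a minimal sketch, assuming a recent JAX release; it uses jax.numpy_rank_promotion with the "raise" setting and jax.default_matmul_precision with "float32", both standard values for those options.

```python
import jax
import jax.numpy as jnp

x = jnp.ones(3)        # rank 1
y = jnp.ones((2, 3))   # rank 2

# Inside the block, implicit rank promotion (x -> shape (1, 3)) is an error.
raised = False
with jax.numpy_rank_promotion("raise"):
    try:
        x + y
    except ValueError:
        raised = True

# Outside the block the default ("allow") applies and NumPy broadcasting works.
z = x + y

# Request full float32 precision for matmuls issued inside the block
# (on CPU this is typically already what you get).
with jax.default_matmul_precision("float32"):
    m = y @ x
```

The option reverts to its previous value when the with block exits, which is why z is computed without error.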

Just-in-time compilation (jit)

jit(fun[, in_shardings, out_shardings, ...])

Sets up fun for just-in-time compilation with XLA.

disable_jit([disable])

Context manager that disables jit() behavior under its dynamic context.

ensure_compile_time_eval()

Context manager to ensure evaluation at trace/compile time (or error).

make_jaxpr([axis_env, return_shape, ...])

Creates a function that produces its jaxpr given example args.

eval_shape(fun, *args, **kwargs)

Compute the shape/dtype of fun without any FLOPs.

ShapeDtypeStruct(shape, dtype, *[, ...])

A container for the shape, dtype, and other static attributes of an array.

device_put(x[, device, src, donate, may_alias])

Transfers x to device.

device_get(x)

Transfer x to host.

default_backend()

Returns the platform name of the default XLA backend.

named_call(fun, *[, name])

Adds a user-specified name to a function when staging out JAX computations.

named_scope(name)

A context manager that adds a user-specified name to the JAX name stack.

block_until_ready(x)

Tries to call a block_until_ready method on pytree leaves.

make_mesh(axis_shapes, axis_names, *[, ...])

Creates an efficient mesh with the shape and axis names specified.
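A short tour of several entries in this section, sketched under the assumption of a single default device (e.g. a CPU-only install). The helper names power_sum and abs_value are illustrative only; static_argnames is a real jit parameter.

```python
import functools
import numpy as np
import jax
import jax.numpy as jnp

# n is static: each distinct value of n compiles a separate program.
@functools.partial(jax.jit, static_argnames=("n",))
def power_sum(x, n):
    return jnp.sum(x ** n)

x = jax.device_put(np.arange(4, dtype=np.float32))   # host -> default device

# Dispatch is asynchronous; block before reading or timing results.
result = jax.block_until_ready(power_sum(x, n=2))    # 0 + 1 + 4 + 9
host_result = jax.device_get(result)                 # device -> NumPy

# Shape/dtype inference without running any FLOPs.
spec = jax.ShapeDtypeStruct((128, 64), jnp.float32)
out = jax.eval_shape(lambda a: a @ a.T, spec)        # shape (128, 128), float32

# Inspect the staged-out program as a jaxpr.
jaxpr = jax.make_jaxpr(lambda a: jnp.sin(a) * 2.0)(1.0)

# Python control flow on array values fails under tracing, but under
# disable_jit the jitted function runs eagerly on concrete values.
@jax.jit
def abs_value(v):
    return v if v > 0 else -v

with jax.disable_jit():
    pos = abs_value(jnp.float32(-3.0))
```

Calling abs_value outside disable_jit would fail, because v > 0 cannot be converted to a Python bool while v is a tracer.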

Automatic differentiation

grad(fun[, argnums, has_aux, holomorphic, ...])

Creates a function that evaluates the gradient of fun.

value_and_grad(fun[, argnums, has_aux, ...])

Create a function that evaluates both fun and the gradient of fun.

jacobian(fun[, argnums, has_aux, ...])

Alias of jax.jacrev().

jacfwd(fun[, argnums, has_aux, holomorphic])

Jacobian of fun evaluated column-by-column using forward-mode AD.

jacrev(fun[, argnums, has_aux, holomorphic, ...])

Jacobian of fun evaluated row-by-row using reverse-mode AD.

hessian(fun[, argnums, has_aux, holomorphic])

Hessian of fun as a dense array.

jvp(fun, primals, tangents[, has_aux])

Computes a (forward-mode) Jacobian-vector product of fun.

linearize()

Produces a linear approximation to fun using jvp() and partial eval.

linear_transpose(fun, *primals[, reduce_axes])

Transpose a function that is promised to be linear.

vjp()

Compute a (reverse-mode) vector-Jacobian product of fun.

custom_gradient(fun)

Convenience function for defining custom VJP rules (aka custom gradients).

closure_convert(fun, *example_args)

Closure conversion utility, for use with higher-order custom derivatives.

checkpoint(fun, *[, prevent_cse, policy, ...])

Make fun recompute internal linearization points when differentiated.
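The differentiation entries above compose as ordinary function transformations. This is a minimal sketch with small hand-checkable functions (loss, square and layer are illustrative names, not part of the API).

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.sum((w * x) ** 2)

def square(v):
    return v ** 2          # elementwise, so its Jacobian is diag(2 * v)

def layer(v):
    return jnp.tanh(v) * jnp.sin(v)

w = jnp.array([1.0, 2.0])
x = jnp.array([3.0, 4.0])

# Gradient with respect to argument 0 (argnums=0 is the default).
g = jax.grad(loss)(w, x)                    # 2 * w * x**2 -> [18., 64.]
val, g2 = jax.value_and_grad(loss)(w, x)    # 73.0 and the same gradient

J_fwd = jax.jacfwd(square)(x)   # forward mode; suited to "tall" Jacobians
J_rev = jax.jacrev(square)(x)   # reverse mode; suited to "wide" Jacobians
H = jax.hessian(lambda v: jnp.sum(v ** 3))(x)   # diag(6 * v)

# jvp pushes a tangent forward (J @ t); vjp pulls a cotangent back (c @ J).
t = jnp.ones_like(x)
y, jv = jax.jvp(square, (x,), (t,))
y2, square_vjp = jax.vjp(square, x)
(vj,) = square_vjp(t)

# checkpoint recomputes layer's intermediates on the backward pass instead
# of storing them; the resulting gradient is the same.
layer_ckpt = jax.checkpoint(layer)
g_plain = jax.grad(lambda v: jnp.sum(layer(layer(v))))(x)
g_ckpt = jax.grad(lambda v: jnp.sum(layer_ckpt(layer_ckpt(v))))(x)
```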

Customization

custom_jvp

custom_jvp(fun[, nondiff_argnums])

Set up a JAX-transformable function for a custom JVP rule definition.

custom_jvp.defjvp(jvp[, symbolic_zeros])

Define a custom JVP rule for the function represented by this instance.

custom_jvp.defjvps(*jvps)

Convenience wrapper for defining JVPs for each argument separately.
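A common reason to use custom_jvp is numerical stability. In this sketch (log1pexp is an illustrative function), the default derivative of log(1 + e^x) evaluates inf/inf at large x, while the custom rule uses the equivalent, stable form sigmoid(x).

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log1p(jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    ans = log1pexp(x)
    # d/dx log(1 + e^x) = sigmoid(x); this form stays finite for large x.
    ans_dot = jax.nn.sigmoid(x) * x_dot
    return ans, ans_dot

g = jax.grad(log1pexp)(100.0)   # finite (about 1.0) instead of nan
```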

custom_vjp

custom_vjp(fun[, nondiff_argnums])

Set up a JAX-transformable function for a custom VJP rule definition.

custom_vjp.defvjp(fwd, bwd[, ...])

Define a custom VJP rule for the function represented by this instance.
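With custom_vjp, you supply a forward function that returns residuals and a backward function that maps residuals and the output cotangent to one cotangent per input. The sketch below (clip_gradient is an illustrative name) leaves values unchanged but clips the gradient flowing through them.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_gradient(lo, hi, x):
    return x  # identity on the forward pass

def clip_gradient_fwd(lo, hi, x):
    # Save the bounds as residuals for the backward pass.
    return x, (lo, hi)

def clip_gradient_bwd(res, g):
    lo, hi = res
    # One cotangent per primal input; None means zero cotangent for lo and hi.
    return (None, None, jnp.clip(g, lo, hi))

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

def f(x):
    return 10.0 * clip_gradient(-1.0, 1.0, x)

value = f(2.0)             # 20.0: the forward value is unaffected
grad_value = jax.grad(f)(2.0)   # 1.0: the incoming cotangent 10.0 is clipped
```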

custom_batching

custom_batching.custom_vmap(fun)

Customize the vmap behavior of a JAX-transformable function.

custom_batching.custom_vmap.def_vmap(vmap_rule)

Define the vmap rule for this custom_vmap function.

custom_batching.sequential_vmap(f)

A special case of custom_vmap that uses a loop.
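A minimal custom_vmap sketch. The vmap rule receives the batch size, one boolean per argument saying whether it is batched, and the batched arguments. It returns the output and whether that output is batched. Here the rule deliberately computes something different (a product instead of a sum) so the effect is visible; f and f_vmap_rule are illustrative names.

```python
import jax
import jax.numpy as jnp

@jax.custom_batching.custom_vmap
def f(x, y):
    return x + y

@f.def_vmap
def f_vmap_rule(axis_size, in_batched, xs, ys):
    # This sketch only handles the case where every argument is batched.
    assert all(in_batched)
    assert xs.shape[0] == axis_size
    out_batched = True
    return xs * ys, out_batched

unbatched = f(jnp.int32(2), jnp.int32(3))     # 5: uses the function body
xs = jnp.arange(3)
ys = jnp.arange(1, 4)
batched = jax.vmap(f)(xs, ys)                 # [0, 2, 6]: uses the rule
```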

jax.Array (jax.Array)

Array()

Array base class for JAX.

make_array_from_callback(shape, sharding, ...)

Returns a jax.Array via data fetched from data_callback.

make_array_from_single_device_arrays(shape, ...)

Returns a jax.Array from a sequence of jax.Arrays each on a single device.

make_array_from_process_local_data(sharding, ...)

Creates a distributed array from the data available in this process.
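A sketch of make_array_from_callback, assuming a single device and therefore a SingleDeviceSharding. With a multi-device sharding (for example a NamedSharding over a mesh), the same callback would be invoked once per addressable shard with that shard's index.

```python
import numpy as np
import jax
from jax.sharding import SingleDeviceSharding

data = np.arange(12, dtype=np.float32).reshape(3, 4)
sharding = SingleDeviceSharding(jax.devices()[0])

# The callback receives a shard's index (a tuple of slices into the global
# shape) and returns the data for that shard.
arr = jax.make_array_from_callback(data.shape, sharding, lambda index: data[index])
```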

Array properties and methods

Array.addressable_shards

List of addressable shards.

Array.all([axis, out, keepdims, where])

Test whether all array elements along a given axis evaluate to True.

Array.any([axis, out, keepdims, where])

Test whether any array elements along a given axis evaluate to True.

Array.argmax([axis, out, keepdims])

Return the index of the maximum value.

Array.argmin([axis, out, keepdims])

Return the index of the minimum value.

Array.argpartition(kth[, axis])

Return the indices that partially sort the array.

Array.argsort([axis, kind, order, stable, ...])

Return the indices that sort the array.

Array.astype(dtype[, copy, device])

Copy the array and cast to a specified dtype.

Array.at

Helper property for index update functionality.

Array.choose(choices[, out, mode])

Construct an array choosing from elements of multiple arrays.

Array.clip([min, max])

Return an array whose values are limited to a specified range.

Array.compress(condition[, axis, out, size, ...])

Return selected slices of this array along given axis.

Array.committed

Whether the array is committed or not.

Array.conj()

Return the complex conjugate of the array.

Array.conjugate()

Return the complex conjugate of the array.

Array.copy()

Return a copy of the array.

Array.copy_to_host_async()

Copies an Array to the host asynchronously.

Array.cumprod([axis, dtype, out])

Return the cumulative product of the array.

Array.cumsum([axis, dtype, out])

Return the cumulative sum of the array.

Array.device

Array API-compatible device attribute.

Array.diagonal([offset, axis1, axis2])

Return the specified diagonal from the array.

Array.dot(b, *[, precision, ...])

Compute the dot product of two arrays.

Array.dtype

The data type (numpy.dtype) of the array.

Array.flat

Use flatten() instead.

Array.flatten([order])

Flatten array into a 1-dimensional shape.

Array.global_shards

List of global shards.

Array.imag

Return the imaginary part of the array.

Array.is_fully_addressable

Is this Array fully addressable?

Array.is_fully_replicated

Is this Array fully replicated?

Array.item(*args)

Copy an element of an array to a standard Python scalar and return it.

Array.itemsize

Length of one array element in bytes.

Array.max([axis, out, keepdims, initial, where])

Return the maximum of array elements along a given axis.

Array.mean([axis, dtype, out, keepdims, where])

Return the mean of array elements along a given axis.

Array.min([axis, out, keepdims, initial, where])

Return the minimum of array elements along a given axis.

Array.nbytes

Total bytes consumed by the elements of the array.

Array.ndim

The number of dimensions in the array.

Array.nonzero(*[, fill_value, size])

Return indices of nonzero elements of an array.

Array.prod([axis, dtype, out, keepdims, ...])

Return product of the array elements over a given axis.

Array.ptp([axis, out, keepdims])

Return the peak-to-peak range along a given axis.

Array.ravel([order])

Flatten array into a 1-dimensional shape.

Array.real

Return the real part of the array.

Array.repeat(repeats[, axis, ...])

Construct an array from repeated elements.

Array.reshape(*args[, order])

Returns an array containing the same data with a new shape.

Array.round([decimals, out])

Round array elements to a given number of decimals.

Array.searchsorted(v[, side, sorter, method])

Perform a binary search within a sorted array.

Array.shape

The shape of the array.

Array.sharding

The sharding for the array.

Array.size

The total number of elements in the array.

Array.sort([axis, kind, order, stable, ...])

Return a sorted copy of an array.

Array.squeeze([axis])

Remove one or more length-1 axes from array.

Array.std([axis, dtype, out, ddof, ...])

Compute the standard deviation along a given axis.

Array.sum([axis, dtype, out, keepdims, ...])

Sum of the elements of the array over a given axis.

Array.swapaxes(axis1, axis2)

Swap two axes of an array.

Array.take(indices[, axis, out, mode, ...])

Take elements from an array.

Array.to_device(device, *[, stream])

Return a copy of the array on the specified device.

Array.trace([offset, axis1, axis2, dtype, out])

Return the sum along the diagonal.

Array.transpose(*args)

Returns a copy of the array with axes transposed.

Array.var([axis, dtype, out, ddof, ...])

Compute the variance along a given axis.

Array.view([dtype, type])

Return a bitwise copy of the array, viewed as a new dtype.

Array.T

Compute the all-axis array transpose.

Array.mT

Compute the (batched) matrix transpose.
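A few of the properties and methods above in use. Methods mirror their NumPy counterparts but always return new arrays, and in-place updates go through the at property.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]]

# Static attributes: shape (2, 3), dtype float32, ndim 2, size 6, nbytes 24.
print(x.shape, x.dtype, x.ndim, x.size, x.nbytes)

col_sums = x.sum(axis=0)          # [3., 5., 7.]
y = x.T.astype(jnp.int32)         # shape (3, 2), dtype int32

# jax.Array is immutable: x[0, 0] = 10.0 raises a TypeError.
# .at[...] returns an updated copy and leaves x untouched.
z = x.at[0, 0].set(10.0).at[1].add(1.0)   # [[10, 1, 2], [4, 5, 6]]

# Placement information.
print(x.sharding, x.is_fully_addressable, len(x.addressable_shards))
```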

Vectorization (vmap)

vmap(fun[, in_axes, out_axes, axis_name, ...])

Vectorizing map.

numpy.vectorize(pyfunc, *[, excluded, signature])

Define a vectorized function with broadcasting.
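Both entries turn a function written for single examples into one that handles a batch. In this sketch, dot is an illustrative per-example function. vmap selects the mapped axis of each argument via in_axes, while jax.numpy.vectorize uses a NumPy gufunc-style signature.

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    return jnp.dot(a, b)   # written for a single pair of vectors

A = jnp.arange(6.0).reshape(3, 2)   # three 2-vectors: [0, 1], [2, 3], [4, 5]
b = jnp.array([1.0, 1.0])

# in_axes=(0, None): map over the rows of A and reuse b in every call.
batched = jax.vmap(dot, in_axes=(0, None))(A, b)    # [1., 5., 9.]

# Same result via a gufunc-style signature; b is broadcast against A.
vec_dot = jnp.vectorize(dot, signature="(n),(n)->()")
batched2 = vec_dot(A, b)
```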

Parallelization (pmap)

pmap(fun[, axis_name, in_axes, out_axes, ...])

Parallel map with support for collective operations.

devices([backend])

Returns a list of all devices for a given backend.

local_devices([process_index, backend, host_id])

Like jax.devices(), but only returns devices local to a given process.

process_index([backend])

Returns the integer process index of this process.

device_count([backend])

Returns the total number of devices.

local_device_count([backend])

Returns the number of devices addressable by this process.

process_count([backend])

Returns the number of JAX processes associated with the backend.

process_indices([backend])

Returns the list of all JAX process indices associated with the backend.
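A minimal pmap sketch that runs on any number of local devices, including the single device of a CPU-only install. The leading axis must equal the local device count, and the axis_name lets collectives such as jax.lax.psum reduce across devices.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()   # 1 on a plain CPU install

# One slice of the input per local device; psum sums over the named axis "i",
# so every device ends up with the same total.
total = jax.pmap(lambda v: jax.lax.psum(v, axis_name="i"), axis_name="i")(
    jnp.arange(float(n))
)

print(jax.devices(), jax.process_index(), jax.process_count(), total)
```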

Callbacks

pure_callback(callback, result_shape_dtypes, ...)

Calls a pure Python callback.

experimental.io_callback(callback, ...[, ...])

Calls an impure Python callback.

debug.callback(callback, *args[, ordered])

Calls a stageable Python callback.

debug.print(fmt, *args[, ordered])

Prints values and works in staged-out JAX functions.
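A sketch combining pure_callback and debug.print inside jit. host_sinc is an illustrative host function. pure_callback needs the result's shape and dtype up front, because JAX cannot trace into NumPy code. debug.print shows runtime values, whereas a plain print() inside a jitted function would only show a tracer at trace time.

```python
import numpy as np
import jax
import jax.numpy as jnp

def host_sinc(x):
    # Runs on the host with NumPy; JAX does not trace into this function.
    return np.sinc(np.asarray(x)).astype(np.float32)

@jax.jit
def f(x):
    out_type = jax.ShapeDtypeStruct(x.shape, jnp.float32)
    y = jax.pure_callback(host_sinc, out_type, x)
    jax.debug.print("sinc(x) = {}", y)   # printed at run time, on every call
    return y.sum()

x = jnp.linspace(-2.0, 2.0, 5)   # [-2, -1, 0, 1, 2]
total = f(x)                     # sinc is 1 at 0 and about 0 at nonzero integers
```

As its name says, pure_callback assumes the callback has no side effects; for side-effecting host code, experimental.io_callback is the counterpart listed above.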

Miscellaneous

Device

A descriptor of an available device.

print_environment_info([return_string])

Prints (or, with return_string=True, returns) a string containing local environment and JAX installation information.

live_arrays([platform])

Return all live arrays in the backend for platform.

clear_caches()

Clear all compilation and staging caches.
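The miscellaneous utilities in a short sketch. The contents of live_arrays() depend on everything allocated in the current session, so only a lower bound is meaningful.

```python
import jax
import jax.numpy as jnp

dev = jax.devices()[0]          # a jax.Device
print(dev.platform, dev.id)

# With return_string=True the report is returned instead of printed.
info = jax.print_environment_info(return_string=True)

a = jnp.ones(4)
live = jax.live_arrays()        # includes a, plus anything else still alive

jax.clear_caches()              # drop compiled and staged computations
```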