Public API: jax package
Subpackages
Just-in-time compilation (jit)

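jit: Sets up fun for just-in-time compilation with XLA.

disable_jit: Context manager that disables jit() behavior under its dynamic context.

xla_computation: Creates a function that produces its XLA computation given example args.

make_jaxpr: Creates a function that produces its jaxpr given example args.

eval_shape: Compute the shape/dtype of fun without any FLOPs.

device_put: Transfers x to device.

named_call: Adds a user specified name to a function when staging out JAX computations.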
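Automatic differentiation

grad: Creates a function which evaluates the gradient of fun.

value_and_grad: Create a function which evaluates both fun and the gradient of fun.

jacfwd: Jacobian of fun evaluated column-by-column using forward-mode AD.

jacrev: Jacobian of fun evaluated row-by-row using reverse-mode AD.

hessian: Hessian of fun as a dense array.

jvp: Computes a (forward-mode) Jacobian-vector product of fun.

linearize: Produces a linear approximation to fun using jvp() and partial eval.

linear_transpose: Transpose a function that is promised to be linear.

vjp: Compute a (reverse-mode) vector-Jacobian product of fun.

custom_jvp: Set up a JAX-transformable function for a custom JVP rule definition.

custom_vjp: Set up a JAX-transformable function for a custom VJP rule definition.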

checkpoint: Make fun recompute internal linearization points when differentiated.
Vectorization (vmap)

vmap: Vectorizing map.

numpy.vectorize: Define a vectorized function with broadcasting.
Parallelization (pmap)

pmap: Parallel map with support for collective operations.

devices: Returns a list of all devices for a given backend.

local_devices: Like jax.devices(), but only returns devices local to a given host.

host_id: Returns the integer host ID of this host.

host_ids: Returns a sorted list of all host IDs.

device_count: Returns the total number of devices.

local_device_count: Returns the number of devices on this host.

host_count: Returns the number of hosts.

jax.jit(fun, static_argnums=(), device=None, backend=None, donate_argnums=())
Sets up fun for just-in-time compilation with XLA.
Parameters
fun (Callable[…, ~T]) – Function to be jitted. Should be a pure function, as side-effects may only be executed once. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_argnums can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.
static_argnums (Union[int, Iterable[int]]) – An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object. Static arguments should be hashable, meaning both __hash__ and __eq__ are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. If the jitted function is called with fewer positional arguments than indicated by static_argnums then an error is raised. Arguments that are not arrays or containers thereof must be marked as static. Defaults to ().
device – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].
backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
donate_argnums (Union[int, Iterable[int]]) – Specify which arguments are “donated” to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to.
Return type
Callable[…, ~T]
Returns
A wrapped version of fun, set up for just-in-time compilation.
In the following example, selu can be compiled into a single fused kernel by XLA:
>>> import jax
>>>
>>> @jax.jit
... def selu(x, alpha=1.67, lmbda=1.05):
...   return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
>>>
>>> key = jax.random.PRNGKey(0)
>>> x = jax.random.normal(key, (10,))
>>> print(selu(x))
[-0.54485  0.27744 -0.29255 -0.91421 -0.62452 -0.24748 -0.85743 -0.78232
  0.76827  0.59566 ]
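The static_argnums description above can be illustrated with a small sketch. The function repeated_square and its argument n are hypothetical names chosen for this example; n drives a Python loop, so it is marked static, and calling with a new value of n triggers a fresh trace and compilation.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper: `n` controls a Python-level loop, so it must be a
# compile-time constant and is therefore listed in static_argnums.
@partial(jax.jit, static_argnums=(1,))
def repeated_square(x, n):
    for _ in range(n):  # unrolled during tracing
        x = x * x
    return x


y2 = repeated_square(jnp.float32(2.0), 2)  # traced and compiled for n=2
y3 = repeated_square(jnp.float32(2.0), 3)  # different static value: recompiled
print(y2, y3)
```

Because static values become part of the compilation cache key, this pattern suits arguments that take only a few distinct values.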

jax.disable_jit()
Context manager that disables jit() behavior under its dynamic context.
For debugging it is useful to have a mechanism that disables jit() everywhere in a dynamic context.
Values that have a data dependence on the arguments to a jitted function are traced and abstracted. For example, an abstract value may be a ShapedArray instance, representing the set of all possible arrays with a given shape and dtype, but not representing one concrete array with specific values. You might notice those if you use a benign side-effecting operation in a jitted function, like a print:
>>> import jax
>>>
>>> @jax.jit
... def f(x):
...   y = x * 2
...   print("Value of y is", y)
...   return y + 3
...
>>> print(f(jax.numpy.array([1, 2, 3])))
Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace(level=0/1)>
[5 7 9]
Here y has been abstracted by jit() to a ShapedArray, which represents an array with a fixed shape and type but an arbitrary value. The value of y is also traced. If we want to see a concrete value while debugging, and avoid the tracer too, we can use the disable_jit() context manager:
>>> import jax
>>>
>>> with jax.disable_jit():
...   print(f(jax.numpy.array([1, 2, 3])))
...
Value of y is [2 4 6]
[5 7 9]

jax.xla_computation(fun, static_argnums=(), axis_env=None, in_parts=None, out_parts=None, backend=None, tuple_args=False, instantiate_const_outputs=None, return_shape=False, donate_argnums=())
Creates a function that produces its XLA computation given example args.
Parameters
fun (Callable) – Function from which to form XLA computations.
static_argnums (Union[int, Iterable[int]]) – See the jax.jit() docstring.
axis_env (Optional[Sequence[Tuple[Any, int]]]) – Optional, a sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap(). See the examples below.
in_parts – Optional, how each argument to fun should be partitioned or replicated. This is used to specify partitioned XLA computations, see sharded_jit for more info.
out_parts – Optional, how each output of fun should be partitioned or replicated. This is used to specify partitioned XLA computations, see sharded_jit for more info.
backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
tuple_args (bool) – Optional bool, defaults to False. If True, the resulting XLA computation will have a single tuple argument that is unpacked into the specified function arguments. If None, tupling will be enabled when there are more than 100 arguments, since some platforms have limits on argument arity.
instantiate_const_outputs (Optional[bool]) – Deprecated argument, does nothing.
return_shape (bool) – Optional boolean, defaults to False. If True, the wrapped function returns a pair where the first element is the XLA computation and the second element is a pytree with the same structure as the output of fun and where the leaves are objects with shape and dtype attributes representing the corresponding types of the output leaves.
donate_argnums (Union[int, Iterable[int]]) – Specify which arguments are “donated” to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to.
Returns
A wrapped version of fun that when applied to example arguments returns a built XLA Computation (see xla_client.py), from which representations of the unoptimized XLA HLO computation can be extracted using methods like as_hlo_text, as_serialized_hlo_module_proto, and as_hlo_dot_graph. If the argument return_shape is True, then the wrapped function returns a pair where the first element is the XLA Computation and the second element is a pytree representing the structure, shapes, and dtypes of the output of fun.
For example:
>>> import jax
>>>
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> c = jax.xla_computation(f)(3.)
>>> print(c.as_hlo_text())
HloModule xla_computation_f.6

ENTRY xla_computation_f.6 {
  constant.2 = pred[] constant(false)
  parameter.1 = f32[] parameter(0)
  cosine.3 = f32[] cosine(parameter.1)
  sine.4 = f32[] sine(cosine.3)
  ROOT tuple.5 = (f32[]) tuple(sine.4)
}
Here’s an example that involves a parallel collective and axis name:
>>> def f(x): return x - jax.lax.psum(x, 'i')
>>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2)
>>> print(c.as_hlo_text())
HloModule jaxpr_computation.9

primitive_computation.3 {
  parameter.4 = s32[] parameter(0)
  parameter.5 = s32[] parameter(1)
  ROOT add.6 = s32[] add(parameter.4, parameter.5)
}

ENTRY jaxpr_computation.9 {
  tuple.1 = () tuple()
  parameter.2 = s32[] parameter(0)
  all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3
  ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7)
}
Notice the replica_groups that were generated. Here’s an example that generates more interesting replica_groups:
>>> from jax import lax
>>> def g(x):
...   rowsum = lax.psum(x, 'i')
...   colsum = lax.psum(x, 'j')
...   allsum = lax.psum(x, ('i', 'j'))
...   return rowsum, colsum, allsum
...
>>> axis_env = [('i', 4), ('j', 2)]
>>> c = jax.xla_computation(g, axis_env=axis_env)(5.)
>>> print(c.as_hlo_text())
HloModule jaxpr_computation__1.19
[removed uninteresting text here]
ENTRY jaxpr_computation__1.19 {
  tuple.1 = () tuple()
  parameter.2 = f32[] parameter(0)
  all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3
  all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8
  all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
  ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
}

jax.make_jaxpr(fun, static_argnums=(), return_shape=False)
Creates a function that produces its jaxpr given example args.
Parameters
fun (Callable) – The function whose jaxpr is to be computed. Its positional arguments and return value should be arrays, scalars, or standard Python containers (tuple/list/dict) thereof.
static_argnums (Union[int, Iterable[int]]) – See the jax.jit() docstring.
return_shape (bool) – Optional boolean, defaults to False. If True, the wrapped function returns a pair where the first element is the jaxpr and the second element is a pytree with the same structure as the output of fun and where the leaves are objects with shape and dtype attributes representing the corresponding types of the output leaves.
Return type
Callable[…, ClosedJaxpr]
Returns
A wrapped version of fun that when applied to example arguments returns a ClosedJaxpr representation of fun on those arguments. If the argument return_shape is True, then the returned function instead returns a pair where the first element is the ClosedJaxpr representation of fun and the second element is a pytree representing the structure, shape, and dtypes of the output of fun.
A jaxpr is JAX’s intermediate representation for program traces. The jaxpr language is based on the simply-typed first-order lambda calculus with let-bindings. make_jaxpr() adapts a function to return its jaxpr, which we can inspect to understand what JAX is doing internally. The jaxpr returned is a trace of fun abstracted to ShapedArray level. Other levels of abstraction exist internally.
We do not describe the semantics of the jaxpr language in detail here, but instead give a few examples.
>>> import jax
>>>
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> print(f(3.0))
-0.83602
>>> jax.make_jaxpr(f)(3.0)
{ lambda  ; a.
  let b = cos a
      c = sin b
  in (c,) }
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda  ; a.
  let b = cos a
      c = sin a
      _ = sin b
      d = cos b
      e = mul 1.0 d
      f = neg e
      g = mul f c
  in (g,) }

jax.eval_shape(fun, *args, **kwargs)
Compute the shape/dtype of fun without any FLOPs.
This utility function is useful for performing shape inference. Its input/output behavior is defined by:
def eval_shape(fun, *args, **kwargs):
  out = fun(*args, **kwargs)
  return jax.tree_util.tree_map(shape_dtype_struct, out)

def shape_dtype_struct(x):
  return ShapeDtypeStruct(x.shape, x.dtype)

class ShapeDtypeStruct:
  __slots__ = ["shape", "dtype"]
  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = dtype
In particular, the output is a pytree of objects that have shape and dtype attributes, but nothing else about them is guaranteed by the API.
But instead of applying fun directly, which might be expensive, it uses JAX’s abstract interpretation machinery to evaluate the shapes without doing any FLOPs.
Using eval_shape() can also catch shape errors, and will raise the same shape errors as evaluating fun(*args, **kwargs).
Parameters
fun (Callable) – The function whose output shape should be evaluated.
*args – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types. Since only the shape and dtype attributes are accessed, only values that duck-type arrays are required, rather than real ndarrays. The duck-typed objects cannot be namedtuples because those are treated as standard Python containers. See the example below.
**kwargs – a keyword argument dict of arrays, scalars, or (nested) standard Python containers (pytrees) of those types. As in args, array values need only be duck-typed to have shape and dtype attributes.
For example:
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> f = lambda A, x: jnp.tanh(jnp.dot(A, x))
>>> class MyArgArray(object):
...   def __init__(self, shape, dtype):
...     self.shape = shape
...     self.dtype = dtype
...
>>> A = MyArgArray((2000, 3000), jnp.float32)
>>> x = MyArgArray((3000, 1000), jnp.float32)
>>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
>>> print(out.shape)
(2000, 1000)
>>> print(out.dtype)
float32
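As a complement to the duck-typed example, the sketch below checks that a shape mismatch surfaces through eval_shape() without allocating the large arrays. It assumes jax.ShapeDtypeStruct is available in the installed JAX version; the exact exception type raised for the mismatch is not assumed, so the code catches Exception broadly.

```python
import jax
import jax.numpy as jnp


def f(A, x):
    return jnp.tanh(jnp.dot(A, x))


# Assumption: jax.ShapeDtypeStruct is exported by the installed JAX version.
A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
out = jax.eval_shape(f, A, x)  # shape inference only, no FLOPs
print(out.shape, out.dtype)

# A mismatched inner dimension should be reported during abstract evaluation.
bad_x = jax.ShapeDtypeStruct((2999, 1000), jnp.float32)
try:
    jax.eval_shape(f, A, bad_x)
    shape_error_caught = False
except Exception:  # the exact exception class is deliberately not assumed
    shape_error_caught = True
print(shape_error_caught)
```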

jax.device_put(x, device=None)
Transfers x to device.
Parameters
x – An array, scalar, or (nested) standard Python container thereof.
device (Optional[Device]) – The (optional) Device to which x should be transferred. If given, then the result is committed to the device.
If the device parameter is None, then this operation behaves like the identity function if the operand is on any device already, otherwise it transfers the data to the default device, uncommitted.
For more details on data placement see the FAQ on data placement.
Returns
A copy of x that resides on device.

jax.grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)
Creates a function which evaluates the gradient of fun.
Parameters
fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. Argument arrays in the positions specified by argnums must be of inexact (i.e., floating-point or complex) type. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.
allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
Returns
A function with the same arguments as fun, that evaluates the gradient of fun. If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a pair of (gradient, auxiliary_data) is returned.
For example:
>>> import jax
>>>
>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.961043
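The argnums and has_aux options described above can be combined, as in this minimal sketch. The function loss and its arguments w, b, and x are hypothetical names chosen for illustration; loss returns a (scalar, auxiliary data) pair, and the gradient is taken with respect to the first two positional arguments.

```python
import jax
import jax.numpy as jnp


# Hypothetical loss: returns (scalar output, auxiliary data).
def loss(w, b, x):
    pred = w * x + b
    return jnp.sum(pred ** 2), pred


x = jnp.array([1.0, 2.0])

# argnums=(0, 1): the gradient is a tuple (d/dw, d/db).
# has_aux=True: the result is a pair (gradient, auxiliary_data).
(dw, db), pred = jax.grad(loss, argnums=(0, 1), has_aux=True)(3.0, 1.0, x)
print(dw, db, pred)
```

Here pred is [4, 7], so the gradients work out to 2 * (4*1 + 7*2) = 36 for w and 2 * (4 + 7) = 22 for b.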

jax.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False)
Create a function which evaluates both fun and the gradient of fun.
Parameters
fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.
allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
Returns
A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a sequence of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments.
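This entry has no example of its own, so here is a minimal sketch. The function sum_of_squares is a hypothetical name chosen for illustration; the call returns the function value and its gradient in a single pass.

```python
import jax
import jax.numpy as jnp


# Hypothetical scalar-valued function used only for illustration.
def sum_of_squares(x):
    return jnp.sum(x ** 2)


x = jnp.array([1.0, 2.0, 3.0])
value, gradient = jax.value_and_grad(sum_of_squares)(x)
print(value)     # the function value, sum of x**2
print(gradient)  # the gradient, 2 * x
```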

jax.jacfwd(fun, argnums=0, holomorphic=False)
Jacobian of fun evaluated column-by-column using forward-mode AD.
Parameters
fun (Callable) – Function whose Jacobian is to be computed.
argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
Returns
A function with the same arguments as fun, that evaluates the Jacobian of fun using forward-mode automatic differentiation.
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...   return jnp.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]

jax.jacrev(fun, argnums=0, holomorphic=False, allow_int=False)
Jacobian of fun evaluated row-by-row using reverse-mode AD.
Parameters
fun (Callable) – Function whose Jacobian is to be computed.
argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
allow_int (bool) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
Returns
A function with the same arguments as fun, that evaluates the Jacobian of fun using reverse-mode automatic differentiation.
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...   return jnp.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacrev(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]
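jacfwd() and jacrev() compute the same matrix by different routes (forward mode builds it column by column, reverse mode row by row). The sketch below, using a hypothetical function f mapping 2 inputs to 3 outputs, checks that the two agree. A common rule of thumb, not stated in the text above, is to prefer forward mode when there are fewer inputs than outputs and reverse mode otherwise.

```python
import jax
import jax.numpy as jnp


# Hypothetical function from R^2 to R^3, used only for illustration.
def f(x):
    return jnp.asarray([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])


x = jnp.array([1.0, 2.0])
J_fwd = jax.jacfwd(f)(x)  # built column-by-column (one JVP per input)
J_rev = jax.jacrev(f)(x)  # built row-by-row (one VJP per output)
print(J_fwd.shape)
print(jnp.allclose(J_fwd, J_rev))
```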

jax.hessian(fun, argnums=0, holomorphic=False)
Hessian of fun as a dense array.
Parameters
fun (Callable) – Function whose Hessian is to be computed. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers thereof. It should return arrays, scalars, or standard Python containers thereof.
argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
Returns
A function with the same arguments as fun, that evaluates the Hessian of fun.
>>> import jax
>>>
>>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
>>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
[[   6.   -2.]
 [  -2. -480.]]
hessian() is a generalization of the usual definition of the Hessian that supports nested Python containers (i.e. pytrees) as inputs and outputs. The tree structure of jax.hessian(fun)(x) is given by forming a tree product of the structure of fun(x) with a tree product of two copies of the structure of x. A tree product of two tree structures is formed by replacing each leaf of the first tree with a copy of the second. For example:
>>> import jax.numpy as jnp
>>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
>>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
{'c': {'a': {'a': DeviceArray([[[ 2.,  0.], [ 0.,  0.]],
                               [[ 0.,  0.], [ 0., 12.]]], dtype=float32),
             'b': DeviceArray([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                               [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32)},
       'b': {'a': DeviceArray([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                               [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32),
             'b': DeviceArray([[[0.      , 0.      ], [0.      , 0.      ]],
                               [[0.      , 0.      ], [0.      , 3.843624]]], dtype=float32)}}}
Thus each leaf in the tree structure of jax.hessian(fun)(x) corresponds to a leaf of fun(x) and a pair of leaves of x. For each leaf in jax.hessian(fun)(x), if the corresponding array leaf of fun(x) has shape (out_1, out_2, ...) and the corresponding array leaves of x have shape (in_1_1, in_1_2, ...) and (in_2_1, in_2_2, ...) respectively, then the Hessian leaf has shape (out_1, out_2, ..., in_1_1, in_1_2, ..., in_2_1, in_2_2, ...). In other words, the Python tree structure represents the block structure of the Hessian, with blocks determined by the input and output pytrees.
In particular, an array is produced (with no pytrees involved) when the function input x and output fun(x) are each a single array, as in the g example above. If fun(x) has shape (out1, out2, ...) and x has shape (in1, in2, ...) then jax.hessian(fun)(x) has shape (out1, out2, ..., in1, in2, ..., in1, in2, ...). To flatten pytrees into 1D vectors, consider using jax.flatten_util.ravel_pytree().

jax.jvp(fun, primals, tangents)
Computes a (forward-mode) Jacobian-vector product of fun.
Parameters
fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
primals – The primal values at which the Jacobian of fun should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters of fun.
tangents – The tangent vector for which the Jacobian-vector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as primals.
Returns
A (primals_out, tangents_out) pair, where primals_out is fun(*primals), and tangents_out is the Jacobian-vector product of fun evaluated at primals with tangents. The tangents_out value has the same Python tree structure and shapes as primals_out.
For example:
>>> import jax
>>>
>>> y, v = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
>>> print(y)
0.09983342
>>> print(v)
0.19900084

jax.linearize(fun, *primals)
Produces a linear approximation to fun using jvp() and partial eval.
Parameters
fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
primals – The primal values at which the Jacobian of fun should be evaluated. Should be a tuple of arrays, scalar, or standard Python container thereof. The length of the tuple is equal to the number of positional parameters of fun.
Returns
A pair where the first element is the value of f(*primals) and the second element is a function that evaluates the (forward-mode) Jacobian-vector product of fun evaluated at primals without redoing the linearization work.
In terms of values computed, linearize() behaves much like a curried jvp(), where these two code blocks compute the same values:
y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

y, f_jvp = jax.linearize(f, x)
out_tangent = f_jvp(in_tangent)
However, the difference is that linearize() uses partial evaluation so that the function f is not re-linearized on calls to f_jvp. In general that means the memory usage scales with the size of the computation, much like in reverse-mode. (Indeed, linearize() has a similar signature to vjp()!)
This function is mainly useful if you want to apply f_jvp multiple times, i.e. to evaluate a pushforward for many different input tangent vectors at the same linearization point. Moreover if all the input tangent vectors are known at once, it can be more efficient to vectorize using vmap(), as in:
pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
By using vmap() and jvp() together like this we avoid the stored-linearization memory cost that scales with the depth of the computation, which is incurred by both linearize() and vjp().
Here’s a more complete example of using linearize():
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
...
>>> jax.jvp(f, (2.,), (3.,))
(DeviceArray(3.26819, dtype=float32), DeviceArray(-5.00753, dtype=float32))
>>> y, f_jvp = jax.linearize(f, 2.)
>>> print(y)
3.2681944
>>> print(f_jvp(3.))
-5.007528
>>> print(f_jvp(4.))
-6.676704
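The vmap()/jvp() fragment above omits its imports and inputs. The following sketch fills them in with assumed values (the same f as in the example, the linearization point x = 2.0, and two tangents) and checks that both approaches produce the same pushforwards.

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(x):
    return 3.0 * jnp.sin(x) + jnp.cos(x / 2.0)


x = jnp.float32(2.0)                    # linearization point (assumed value)
in_tangents = jnp.array([3.0, 4.0])     # batch of input tangents (assumed)

# Approach 1: linearize once, then apply the stored linear map repeatedly.
y, f_jvp = jax.linearize(f, x)
out_lin = jnp.stack([f_jvp(t) for t in in_tangents])

# Approach 2: vectorize jvp over the tangents; nothing is stored between calls.
pushfwd = partial(jax.jvp, f, (x,))
y_vmap, out_vmap = jax.vmap(pushfwd, out_axes=(None, 0))((in_tangents,))

print(out_lin)
print(out_vmap)
```

The primal output is identical for every tangent, which is why out_axes uses None for it and 0 only for the batched tangent outputs.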

jax.linear_transpose(fun, *primals)
Transpose a function that is promised to be linear.
For linear functions, this transformation is equivalent to vjp, but avoids the overhead of computing the forward pass.
The outputs of the transposed function will always have the exact same dtypes as primals, even if some values are truncated (e.g., from complex to float, or from float64 to float32). To avoid truncation, use dtypes in primals that match the full range of desired outputs from the transposed function. Integer dtypes are not supported.
Parameters
fun (Callable) – the linear function to be transposed.
*primals – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e., pytrees) of those types used for evaluating the shape/dtype of fun(*primals). These arguments may be real scalars/ndarrays, but that is not required: only the shape and dtype attributes are accessed. See below for an example. (Note that the duck-typed objects cannot be namedtuples because those are treated as standard Python containers.)
Returns
A callable that calculates the transpose of fun. Valid input into this function must have the same shape/dtypes/structure as the result of fun(*primals). Output will be a tuple, with the same shape/dtypes/structure as primals.
>>> import jax
>>> import numpy as np
>>> import types
>>>
>>> f = lambda x, y: 0.5 * x - 0.5 * y
>>> scalar = types.SimpleNamespace(shape=(), dtype=np.float32)
>>> f_transpose = jax.linear_transpose(f, scalar, scalar)
>>> f_transpose(1.0)
(DeviceArray(0.5, dtype=float32), DeviceArray(-0.5, dtype=float32))

jax.vjp(fun, *primals, has_aux=False)
Compute a (reverse-mode) vector-Jacobian product of fun.
grad() is implemented as a special case of vjp().
Parameters
fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
primals – A sequence of primal values at which the Jacobian of fun should be evaluated. The length of primals should be equal to the number of positional parameters to fun. Each primal value should be a tuple of arrays, scalar, or standard Python containers thereof.
has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
Returns
If has_aux is False, returns a (primals_out, vjpfun) pair, where primals_out is fun(*primals). vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same shape as primals, representing the vector-Jacobian product of fun evaluated at primals. If has_aux is True, returns a (primals_out, vjpfun, aux) tuple where aux is the auxiliary data returned by fun.
>>> import jax
>>>
>>> def f(x, y):
...   return jax.numpy.sin(x), jax.numpy.cos(y)
...
>>> primals, f_vjp = jax.vjp(f, 0.5, 1.0)
>>> xbar, ybar = f_vjp((0.7, 0.3))
>>> print(xbar)
0.61430776
>>> print(ybar)
-0.2524413
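To make the remark that grad() is a special case of vjp() concrete, the sketch below recovers a gradient by pulling back the cotangent 1.0 through a scalar-valued function. The function f here is a hypothetical example, and the code illustrates the relationship rather than reproducing JAX's internal implementation.

```python
import jax
import jax.numpy as jnp


# Hypothetical scalar-valued function, used only for illustration.
def f(x):
    return jnp.sum(jnp.sin(x) ** 2)


x = jnp.array([0.1, 0.2, 0.3])

# vjp returns the primal output and a function that maps output cotangents
# to a tuple of input cotangents (one per positional argument).
y, f_vjp = jax.vjp(f, x)
(g_from_vjp,) = f_vjp(jnp.ones_like(y))  # cotangent 1.0 for the scalar output

g_from_grad = jax.grad(f)(x)
print(g_from_vjp)
print(g_from_grad)
```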

class jax.custom_jvp(fun, nondiff_argnums=())
Set up a JAX-transformable function for a custom JVP rule definition.
This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a differentiation transformation (like jax.jvp() or jax.grad()) is applied, in which case a custom user-supplied JVP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation.
There are two instance methods available for defining the custom JVP rule: defjvp() for defining a single custom JVP rule for all the function’s inputs, and for convenience defjvps(), which wraps defjvp(), and allows you to provide separate definitions for the partial derivatives of the function w.r.t. each of its arguments.
For example:
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y):
  return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out
For a more detailed introduction, see the tutorial.

defjvp(jvp)
Define a custom JVP rule for the function represented by this instance.
Parameters
jvp – a Python callable representing the custom JVP rule. When there are no nondiff_argnums, the jvp function should accept two arguments, where the first is a tuple of primal inputs and the second is a tuple of tangent inputs. The lengths of both tuples are equal to the number of parameters of the custom_jvp function. The jvp function should produce as output a pair where the first element is the primal output and the second element is the tangent output. Elements of the input and output tuples may be arrays or any nested tuples/lists/dicts thereof.
Returns
None.
Example:
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y):
  return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out

defjvps(*jvps)
Convenience wrapper for defining JVPs for each argument separately.
This convenience wrapper cannot be used together with nondiff_argnums.
 Parameters
*jvps – a sequence of functions, one for each positional argument of the custom_jvp function. Each function takes as arguments the tangent value for the corresponding primal input, the primal output, and the primal inputs. See the example below.
 Returns
None.
Example:
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y):
  return jnp.sin(x) * y

f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
          lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)
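To see the per-argument rules in action, the following sketch pushes a tangent through the decorated function with jax.jvp() and also differentiates it with jax.grad(). The evaluation point (2.0, 3.0) and the variable names are illustrative choices, not part of the API.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y

# One rule per positional argument; each returns that argument's
# contribution to the output tangent.
f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
          lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)

x, y = 2.0, 3.0  # illustrative evaluation point

# Tangent (1, 0) isolates the partial derivative with respect to x.
primal_out, tangent_out = jax.jvp(f, (x, y), (1.0, 0.0))

# Reverse-mode differentiation is derived from the same custom JVP rule.
df_dx, df_dy = jax.grad(f, argnums=(0, 1))(x, y)

print(primal_out, tangent_out, df_dx, df_dy)
```

Both transformations should agree with the analytic partial derivatives cos(x) * y and sin(x).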


class jax.custom_vjp(fun, nondiff_argnums=())
Set up a JAX-transformable function for a custom VJP rule definition.
This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a reverse-mode differentiation transformation (like jax.grad()) is applied, in which case a custom user-supplied VJP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation. There is a single instance method, defvjp(), which may be used to define the custom VJP rule.
This decorator precludes the use of forward-mode automatic differentiation.
For example:
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)
For a more detailed introduction, see the tutorial.

defvjp(fwd, bwd)
Define a custom VJP rule for the function represented by this instance.
 Parameters
fwd – a Python callable representing the forward pass of the custom VJP rule. When there are no nondiff_argnums, the fwd function has the same input signature as the underlying primal function. It should return as output a pair, where the first element represents the primal output and the second element represents any “residual” values to store from the forward pass for use on the backward pass by the function bwd. Input arguments and elements of the output pair may be arrays or nested tuples/lists/dicts thereof.
bwd – a Python callable representing the backward pass of the custom VJP rule. When there are no nondiff_argnums, the bwd function takes two arguments, where the first is the “residual” values produced on the forward pass by fwd, and the second is the output cotangent with the same structure as the primal function output. The output of bwd must be a tuple of length equal to the number of arguments of the primal function, and the tuple elements may be arrays or nested tuples/lists/dicts thereof so as to match the structure of the primal input arguments.
 Returns
None.
Example:
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)
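As a rough check that the fwd/bwd pair is what reverse-mode differentiation actually uses, the sketch below differentiates the decorated function with jax.grad() and calls the pullback returned by jax.vjp() directly. The evaluation point and the names x_bar and y_bar are illustrative.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
    return jnp.sin(x) * y

def f_fwd(x, y):
    # Primal output plus the residuals needed by the backward pass.
    return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
    cos_x, sin_x, y = res
    # One cotangent per primal argument, in argument order.
    return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)

x, y = 2.0, 3.0  # illustrative evaluation point

# jax.grad() runs f_fwd, then f_bwd with an output cotangent of 1.
df_dx, df_dy = jax.grad(f, argnums=(0, 1))(x, y)

# The same rule, driven explicitly through jax.vjp().
primal_out, f_pullback = jax.vjp(f, x, y)
x_bar, y_bar = f_pullback(jnp.ones_like(primal_out))

print(df_dx, df_dy, x_bar, y_bar)
```

The gradients from both routes should match the analytic values cos(x) * y and sin(x).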


jax.checkpoint(fun, concrete=False)
Make fun recompute internal linearization points when differentiated.
The jax.checkpoint() decorator, aliased to jax.remat, provides a way to trade off computation time and memory cost in the context of automatic differentiation, especially with reverse-mode autodiff like jax.grad() and jax.vjp() but also with jax.linearize().
When differentiating a function in reverse-mode, by default all the linearization points (e.g. inputs to elementwise nonlinear primitive operations) are stored when evaluating the forward pass so that they can be reused on the backward pass. This evaluation strategy can lead to a high memory cost, or even to poor performance on hardware accelerators where memory access is much more expensive than FLOPs.
An alternative evaluation strategy is for some of the linearization points to be recomputed (i.e. rematerialized) rather than stored. This approach can reduce memory usage at the cost of increased computation.
This function decorator produces a new version of fun which follows the rematerialization strategy rather than the default store-everything strategy. That is, it returns a new version of fun which, when differentiated, doesn’t store any of its intermediate linearization points. Instead, these linearization points are recomputed from the function’s saved inputs.
See the examples below.
 Parameters
fun (Callable) – Function for which the autodiff evaluation strategy is to be changed from the default of storing all intermediate linearization points to recomputing them. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
concrete (bool) – Optional, boolean indicating whether fun may involve value-dependent Python control flow (default False). Support for such control flow is optional, and disabled by default, because in some edge-case compositions with jax.jit() it can lead to some extra computation.
 Return type
 Returns
A function (callable) with the same input/output behavior as fun but which, when differentiated using e.g. jax.grad(), jax.vjp(), or jax.linearize(), recomputes rather than stores intermediate linearization points, thus potentially saving memory at the cost of extra computation.
Here is a simple example:
>>> import jax
>>> import jax.numpy as jnp
>>> @jax.checkpoint
... def g(x):
...   y = jnp.sin(x)
...   z = jnp.sin(y)
...   return z
...
>>> jax.grad(g)(2.0)
DeviceArray(-0.25563914, dtype=float32)
Here, the same value is produced whether or not the jax.checkpoint() decorator is present. But when using jax.checkpoint(), the value jnp.sin(2.0) is computed twice: once on the forward pass, and once on the backward pass. The values jnp.cos(2.0) and jnp.cos(jnp.sin(2.0)) are also computed twice. Without using the decorator, both jnp.cos(2.0) and jnp.cos(jnp.sin(2.0)) would be stored and reused.
The jax.checkpoint() decorator can be applied recursively to express sophisticated autodiff rematerialization strategies. For example:
>>> def recursive_checkpoint(funs):
...   if len(funs) == 1:
...     return funs[0]
...   elif len(funs) == 2:
...     f1, f2 = funs
...     return lambda x: f1(f2(x))
...   else:
...     f1 = recursive_checkpoint(funs[:len(funs)//2])
...     f2 = recursive_checkpoint(funs[len(funs)//2:])
...     return lambda x: f1(jax.checkpoint(f2)(x))
...
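One way to exercise the recursive strategy is to build a chain of elementwise functions and compare its gradient with that of the same chain written without rematerialization. The sketch below does this; the particular list of functions and the helper chain_plain are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

def recursive_checkpoint(funs):
    if len(funs) == 1:
        return funs[0]
    elif len(funs) == 2:
        f1, f2 = funs
        return lambda x: f1(f2(x))
    else:
        f1 = recursive_checkpoint(funs[:len(funs)//2])
        f2 = recursive_checkpoint(funs[len(funs)//2:])
        return lambda x: f1(jax.checkpoint(f2)(x))

# Illustrative chain; funs[0] is the outermost function.
funs = [jnp.sin, jnp.cos, jnp.tanh, jnp.sin, jnp.cos]
chain_remat = recursive_checkpoint(funs)

def chain_plain(x):
    # Same composition without any rematerialization.
    for fun in reversed(funs):
        x = fun(x)
    return x

grad_remat = jax.grad(chain_remat)(1.0)
grad_plain = jax.grad(chain_plain)(1.0)
print(grad_remat, grad_plain)
```

The two gradients should agree up to floating-point error: rematerialization changes what is stored between the forward and backward passes, not the value being computed.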

jax.vmap(fun, in_axes=0, out_axes=0, axis_name=None)
Vectorizing map. Creates a function which maps fun over argument axes.
 Parameters
fun (Callable[…, ~T]) – Function to be mapped over additional axes.
in_axes – An integer, None, or (nested) standard Python container (tuple/list/dict) thereof specifying which input array axes to map over.
If each positional argument to fun is an array, then in_axes can be an integer, a None, or a tuple of integers and Nones with length equal to the number of positional arguments to fun. An integer or None indicates which array axis to map over for all arguments (with None indicating not to map any axis), and a tuple indicates which axis to map for each corresponding positional argument. Axis integers must be in the range [-ndim, ndim) for each array, where ndim is the number of dimensions of the corresponding input array.
If the positional arguments to fun are container types, the corresponding element of in_axes can itself be a matching container, so that distinct array axes can be mapped for different container elements. in_axes must be a container tree prefix of the positional argument tuple passed to fun.
At least one positional argument must have in_axes not None. The sizes of the mapped input axes for all mapped positional arguments must all be equal.
out_axes – An integer, None, or (nested) standard Python container (tuple/list/dict) thereof indicating where the mapped axis should appear in the output. All outputs with a mapped axis must have a non-None out_axes specification. Axis integers must be in the range [-ndim, ndim) for each output array, where ndim is the number of dimensions of the array returned by the vmap()-ed function, which is one more than the number of dimensions of the corresponding array returned by fun.
 Return type
Callable[…, ~T]
 Returns
Batched/vectorized version of fun with arguments that correspond to those of fun, but with extra array axes at positions indicated by in_axes, and a return value that corresponds to that of fun, but with extra array axes at positions indicated by out_axes.
For example, we can implement a matrix-matrix product using a vector dot product:
>>> import jax.numpy as jnp
>>> from jax import vmap
>>>
>>> vv = lambda x, y: jnp.vdot(x, y)  #  ([a], [a]) -> []
>>> mv = vmap(vv, (0, None), 0)      #  ([b,a], [a]) -> [b]      (b is the mapped axis)
>>> mm = vmap(mv, (None, 1), 1)      #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)
Here we use [a,b] to indicate an array with shape (a,b). Here are some variants:
>>> mv1 = vmap(vv, (0, 0), 0)   #  ([b,a], [b,a]) -> [b]        (b is the mapped axis)
>>> mv2 = vmap(vv, (0, 1), 0)   #  ([b,a], [a,b]) -> [b]        (b is the mapped axis)
>>> mm2 = vmap(mv2, (1, 1), 0)  #  ([b,c,a], [a,c,b]) -> [c,b]  (c is the mapped axis)
Here’s an example of using container types in in_axes to specify which axes of the container elements to map over:
>>> A, B, C, D = 2, 3, 4, 5
>>> x = jnp.ones((A, B))
>>> y = jnp.ones((B, C))
>>> z = jnp.ones((C, D))
>>> def foo(tree_arg):
...   x, (y, z) = tree_arg
...   return jnp.dot(x, jnp.dot(y, z))
>>> tree = (x, (y, z))
>>> print(foo(tree))
[[12. 12. 12. 12. 12.]
 [12. 12. 12. 12. 12.]]
>>> from jax import vmap
>>> K = 6  # batch size
>>> x = jnp.ones((K, A, B))  # batch axis in different locations
>>> y = jnp.ones((B, K, C))
>>> z = jnp.ones((C, D, K))
>>> tree = (x, (y, z))
>>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
>>> print(vfoo(tree).shape)
(6, 2, 5)
Here’s another example using container types in in_axes, this time a dictionary, to specify the elements of the container to map over:
>>> dct = {'a': 0., 'b': jnp.arange(5.)}
>>> x = 1.
>>> def foo(dct, x):
...   return dct['a'] + dct['b'] + x
>>> out = vmap(foo, in_axes=({'a': None, 'b': 0}, None))(dct, x)
>>> print(out)
[1. 2. 3. 4. 5.]
The results of a vectorized function can be mapped or unmapped. For example, the function below returns a pair with the first element mapped and the second unmapped. Only for unmapped results can we specify out_axes to be None (to keep the result unmapped).
>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.))
(DeviceArray([4., 5.], dtype=float32), 8.0)
If out_axes is specified for an unmapped result, the result is broadcast across the mapped axis:
>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
(DeviceArray([4., 5.], dtype=float32), DeviceArray([8., 8.], dtype=float32))
If out_axes is specified for a mapped result, the result is transposed accordingly.
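The effect of out_axes on a mapped result can be seen in a small sketch. The function scale and the input shape are illustrative; the point is only where the batch axis of size 2 ends up in the output.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6.).reshape((2, 3))  # batch axis (size 2) comes first

def scale(row):
    # Operates on a single row of shape (3,).
    return 2.0 * row

# Batch axis placed first in the output: shape (2, 3).
out_first = jax.vmap(scale, in_axes=0, out_axes=0)(x)

# Batch axis placed second in the output: shape (3, 2).
out_second = jax.vmap(scale, in_axes=0, out_axes=1)(x)

print(out_first.shape, out_second.shape)
```

Here out_second is the transpose of out_first: the values are the same and only the position of the mapped axis differs.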

jax.numpy.vectorize(pyfunc, *, excluded=frozenset({}), signature=None)
Define a vectorized function with broadcasting.
vectorize() is a convenience wrapper for defining vectorized functions with broadcasting, in the style of NumPy’s generalized universal functions. It allows for defining functions that are automatically repeated across any leading dimensions, without the implementation of the function needing to be concerned about how to handle higher dimensional inputs.
jax.numpy.vectorize() has the same interface as numpy.vectorize, but it is syntactic sugar for an auto-batching transformation (vmap()) rather than a Python loop. This should be considerably more efficient, but the implementation must be written in terms of functions that act on JAX arrays.
 Parameters
pyfunc – function to vectorize.
excluded – optional set of integers representing positional arguments for which the function will not be vectorized. These will be passed directly to pyfunc unmodified.
signature – optional generalized universal function signature, e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication. If provided, pyfunc will be called with (and expected to return) arrays with shapes given by the size of corresponding core dimensions. By default, pyfunc is assumed to take scalar arrays as input and output.
 Returns
Vectorized version of the given function.
Here are a few examples of how one could write vectorized linear algebra routines using vectorize():
import jax.numpy as jnp
from functools import partial

@partial(jnp.vectorize, signature='(k),(k)->(k)')
def cross_product(a, b):
  assert a.shape == b.shape and a.ndim == b.ndim == 1
  return jnp.array([a[1] * b[2] - a[2] * b[1],
                    a[2] * b[0] - a[0] * b[2],
                    a[0] * b[1] - a[1] * b[0]])

@partial(jnp.vectorize, signature='(n,m),(m)->(n)')
def matrix_vector_product(matrix, vector):
  assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape
  return matrix @ vector
These functions are only written to handle 1D or 2D arrays (the assert statements will never be violated), but with vectorize they support arbitrary dimensional inputs with NumPy-style broadcasting, e.g.,
>>> cross_product(jnp.ones(3), jnp.ones(3)).shape
(3,)
>>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2, 3)
>>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape
(2, 2, 3)
>>> matrix_vector_product(jnp.ones(3), jnp.ones(3))
ValueError: input with shape (3,) does not have enough dimensions for all core dimensions ('n', 'k') on vectorized function with excluded=frozenset() and signature='(n,k),(k)->(k)'
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2,)
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape
(4, 2)  # not the same as jnp.matmul
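A further sketch, assuming a hypothetical helper named dot, shows a signature with a scalar core output. The leading dimensions of the two inputs are broadcast against each other exactly as in elementwise NumPy operations, so the result can be checked against an explicit broadcast-and-sum.

```python
import jax.numpy as jnp

def dot(a, b):
    # Written for two 1D vectors of equal length.
    return jnp.vdot(a, b)

# '(n),(n)->()' declares two vector inputs and a scalar output per call.
batched_dot = jnp.vectorize(dot, signature='(n),(n)->()')

a = jnp.ones((4, 1, 3))              # leading (loop) dimensions (4, 1)
b = jnp.arange(6.).reshape((2, 3))   # leading (loop) dimensions (2,)

out = batched_dot(a, b)              # loop dimensions broadcast to (4, 2)
expected = jnp.sum(a * b, axis=-1)   # the same computation via broadcasting

print(out.shape)
```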

jax.pmap(fun, axis_name=None, *, in_axes=0, out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None)
Parallel map with support for collective operations.
The purpose of pmap() is to express single-program multiple-data (SPMD) programs. Applying pmap() to a function will compile the function with XLA (similarly to jit()), then execute it in parallel on XLA devices, such as multiple GPUs or multiple TPU cores. Semantically it is comparable to vmap() because both transformations map a function over array axes, but where vmap() vectorizes functions by pushing the mapped axis down into primitive operations, pmap() instead replicates the function and executes each replica on its own XLA device in parallel.
Another key difference with vmap() is that while vmap() can only express pure maps, pmap() enables the use of parallel SPMD collective operations, like all-reduce sum.
The mapped axis size must be less than or equal to the number of local XLA devices available, as returned by jax.local_device_count() (unless devices is specified, see below). For nested pmap() calls, the product of the mapped axis sizes must be less than or equal to the number of XLA devices.
Multi-host platforms: On multi-host platforms such as TPU pods, pmap() is designed to be used in SPMD Python programs, where every host is running the same Python code such that all hosts run the same pmapped function in the same order. Each host should still call the pmapped function with mapped axis size equal to the number of local devices (unless devices is specified, see below), and an array of the same leading axis size will be returned as usual. However, any collective operations in fun will be computed over all participating devices, including those on other hosts, via device-to-device communication. Conceptually, this can be thought of as running a pmap over a single array sharded across hosts, where each host “sees” only its local shard of the input and output. The SPMD model requires that the same multi-host pmaps must be run in the same order on all devices, but they can be interspersed with arbitrary operations running on a single host.
 Parameters
fun (Callable[…, ~T]) – Function to be mapped over argument axes. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_broadcasted_argnums can be anything at all, provided they are hashable and have an equality operation defined.
axis_name (Optional[Any]) – Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied.
in_axes – A non-negative integer, None, or nested Python container thereof that specifies which axes in the input to map over (see vmap()).
out_axes – A non-negative integer, None, or nested Python container thereof indicating where the mapped axis should appear in the output. All outputs with a mapped axis must have a non-None out_axes specification (see vmap()).
static_broadcasted_argnums (Union[int, Iterable[int]]) – An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded. Calling the pmapped function with different values for these constants will trigger recompilation. If the pmapped function is called with fewer positional arguments than indicated by static_broadcasted_argnums then an error is raised. Each of the static arguments will be broadcasted to all devices. Arguments that are not arrays or containers thereof must be marked as static. Defaults to ().
devices – This is an experimental feature and the API is likely to change. Optional, a sequence of Devices to map over. (Available devices can be retrieved via jax.devices()). If specified, the size of the mapped axis must be equal to the number of local devices in the sequence. Nested pmap()s with devices specified in either the inner or outer pmap() are not yet supported.
backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend. ‘cpu’, ‘gpu’, or ‘tpu’.
axis_size (Optional[int]) – Optional; the size of the mapped axis.
donate_argnums (Union[int, Iterable[int]]) – Specify which arguments are “donated” to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to.
global_arg_shapes (Optional[Tuple[Tuple[int, …], …]]) – Optional, must be set when using pmap(sharded_jit) and the partitioned values span multiple processes. The global cross-process per-replica shape of each argument, i.e. does not include the leading pmapped dimension. Can be None for replicated arguments. This API is likely to change in the future.
 Return type
Callable[…, ~T]
 Returns
A parallelized version of fun with arguments that correspond to those of fun but with extra array axes at positions indicated by in_axes and with output that has an additional leading array axis (with the same size).
For example, assuming 8 XLA devices are available, pmap() can be used as a map along a leading array axis:
>>> import jax.numpy as jnp
>>> from jax import pmap
>>>
>>> out = pmap(lambda x: x ** 2)(jnp.arange(8))
>>> print(out)
[0, 1, 4, 9, 16, 25, 36, 49]
When the leading dimension is smaller than the number of available devices JAX will simply run on a subset of devices:
>>> x = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2))
>>> y = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
>>> out = pmap(jnp.dot)(x, y)
>>> print(out)
[[[    4.     9.]
  [   12.    29.]]
 [[  244.   345.]
  [  348.   493.]]
 [[ 1412.  1737.]
  [ 1740.  2141.]]]
If your leading dimension is larger than the number of available devices you will get an error:
>>> pmap(lambda x: x ** 2)(jnp.arange(9))
ValueError: ... requires 9 replicas, but only 8 XLA devices are available
As with vmap(), using None in in_axes indicates that an argument doesn’t have an extra axis and should be broadcasted, rather than mapped, across the replicas:
>>> x, y = jnp.arange(2.), 4.
>>> out = pmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None))(x, y)
>>> print(out)
([4., 5.], [8., 8.])
Note that pmap() always returns values mapped over their leading axis, equivalent to using out_axes=0 in vmap().
In addition to expressing pure maps, pmap() can also be used to express parallel single-program multiple-data (SPMD) programs that communicate via collective operations. For example:
>>> f = lambda x: x / jax.lax.psum(x, axis_name='i')
>>> out = pmap(f, axis_name='i')(jnp.arange(4.))
>>> print(out)
[ 0.          0.16666667  0.33333334  0.5       ]
>>> print(out.sum())
1.0
In this example, axis_name is a string, but it can be any Python object with __hash__ and __eq__ defined.
The argument axis_name to pmap() names the mapped axis so that collective operations, like jax.lax.psum(), can refer to it. Axis names are important particularly in the case of nested pmap() functions, where collective operations can operate over distinct axes:
>>> from functools import partial
>>> import jax
>>>
>>> @partial(pmap, axis_name='rows')
... @partial(pmap, axis_name='cols')
... def normalize(x):
...   row_normed = x / jax.lax.psum(x, 'rows')
...   col_normed = x / jax.lax.psum(x, 'cols')
...   doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
...   return row_normed, col_normed, doubly_normed
>>>
>>> x = jnp.arange(8.).reshape((4, 2))
>>> row_normed, col_normed, doubly_normed = normalize(x)
>>> print(row_normed.sum(0))
[ 1.  1.]
>>> print(col_normed.sum(1))
[ 1.  1.  1.  1.]
>>> print(doubly_normed.sum((0, 1)))
1.0
On multi-host platforms, collective operations operate over all devices, including those on other hosts. For example, assuming the following code runs on two hosts with 4 XLA devices each:
>>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
>>> data = jnp.arange(4) if jax.host_id() == 0 else jnp.arange(4,8)
>>> out = pmap(f, axis_name='i')(data)
>>> print(out)
[28 29 30 31] # on host 0
[32 33 34 35] # on host 1
Each host passes in a different length-4 array, corresponding to its 4 local devices, and the psum operates over all 8 values. Conceptually, the two length-4 arrays can be thought of as a sharded length-8 array (in this example equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped axis given name ‘i’. The pmap call on each host then returns the corresponding length-4 output shard.
The devices argument can be used to specify exactly which devices are used to run the parallel computation. For example, again assuming a single host with 8 devices, the following code defines two parallel computations, one which runs on the first six devices and one on the remaining two:
>>> from functools import partial
>>> @partial(pmap, axis_name='i', devices=jax.devices()[:6])
... def f1(x):
...   return x / jax.lax.psum(x, axis_name='i')
>>>
>>> @partial(pmap, axis_name='i', devices=jax.devices()[-2:])
... def f2(x):
...   return jax.lax.psum(x ** 2, axis_name='i')
>>>
>>> print(f1(jnp.arange(6.)))
[0.         0.06666667 0.13333333 0.2        0.26666667 0.33333333]
>>> print(f2(jnp.array([2., 3.])))
[ 13.  13.]
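The examples above assume a fixed number of devices. A minimal sketch that adapts to whatever is available sizes the mapped axis from jax.local_device_count(), so it also runs on a single-device machine. The function name normalize and the input values are illustrative, and the sketch assumes a single-host setup.

```python
import jax
import jax.numpy as jnp

# Size the mapped axis to the local device count, since the mapped
# axis may not exceed the number of local devices.
n = jax.local_device_count()
x = jnp.arange(1.0, n + 1.0)  # [1., 2., ..., n], one element per device

def normalize(v):
    # psum adds v across every replica along the axis named 'i'.
    return v / jax.lax.psum(v, axis_name='i')

out = jax.pmap(normalize, axis_name='i')(x)
print(out, out.sum())
```

Each replica divides its own element by the sum over all replicas, so the outputs add up to 1 regardless of the device count.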

jax.devices(backend=None)
Returns a list of all devices for a given backend.
Each device is represented by a subclass of Device (e.g. CpuDevice, GpuDevice). The length of the returned list is equal to device_count(backend). Local devices can be identified by comparing Device.host_id() to the value returned by jax.host_id().
If backend is None, returns all the devices from the default backend. The default backend is generally 'gpu' or 'tpu' if available, otherwise 'cpu'.

jax.local_devices(host_id=None, backend=None)
Like jax.devices(), but only returns devices local to a given host.
If host_id is None, returns devices local to this host.
 Parameters
host_id (Optional[int]) – the integer ID of the host. Host IDs can be retrieved via jax.host_ids().
backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the xla backend: 'cpu', 'gpu', or 'tpu'.
 Returns
List of Device subclasses.

jax.host_id(backend=None)
Returns the integer host ID of this host.
On most platforms, this will always be 0. On multi-host platforms, the value differs from host to host.

jax.device_count(backend=None)
Returns the total number of devices.
On most platforms, this is the same as jax.local_device_count(). However, on multi-host platforms, this will return the total number of devices across all hosts.
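The device query functions can be cross-checked against each other. The sketch below uses only jax.devices(), jax.local_devices(), jax.device_count() and jax.local_device_count(), and the variable names are illustrative.

```python
import jax

all_devices = jax.devices()          # every device of the default backend
local_devices = jax.local_devices()  # only the devices attached to this host

total = jax.device_count()
local = jax.local_device_count()

print(f"{local} local device(s) out of {total} in total")
for device in local_devices:
    print(device)
```

On a single-host machine the two counts are equal; on a multi-host platform the local count is expected to be smaller than the total.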