Public API: jax package¶

Just-in-time compilation (jit)¶

jit – Sets up fun for just-in-time compilation with XLA.
disable_jit – Context manager that disables jit() behavior under its dynamic context.
xla_computation – Creates a function that produces its XLA computation given example args.
make_jaxpr – Creates a function that produces its jaxpr given example args.
eval_shape – Compute the shape/dtype of fun without any FLOPs.
device_put – Transfers x to device.
device_put_replicated – Transfer array(s) to each specified device and form ShardedDeviceArray(s).
device_put_sharded – Transfer array shards to specified devices and form ShardedDeviceArray(s).
device_get – Transfer x to host.
default_backend – Returns the platform name of the default XLA backend.
named_call – Adds a user-specified name to a function when staging out JAX computations.

Automatic differentiation¶

grad – Creates a function that evaluates the gradient of fun.
value_and_grad – Create a function that evaluates both fun and the gradient of fun.
jacfwd – Jacobian of fun evaluated column-by-column using forward-mode AD.
jacrev – Jacobian of fun evaluated row-by-row using reverse-mode AD.
hessian – Hessian of fun as a dense array.
jvp – Computes a (forward-mode) Jacobian-vector product of fun.
linearize – Produces a linear approximation to fun using jvp() and partial eval.
linear_transpose – Transpose a function that is promised to be linear.
vjp – Compute a (reverse-mode) vector-Jacobian product of fun.
custom_jvp – Set up a JAX-transformable function for a custom JVP rule definition.
custom_vjp – Set up a JAX-transformable function for a custom VJP rule definition.
closure_convert – Closure conversion utility, for use with higher-order custom derivatives.
checkpoint – Make fun recompute internal linearization points when differentiated.

Vectorization (vmap)¶

vmap – Vectorizing map.
jax.numpy.vectorize – Define a vectorized function with broadcasting.

Parallelization (pmap)¶

pmap – Parallel map with support for collective operations.
devices – Returns a list of all devices for a given backend.
local_devices – Like devices(), but only returns devices local to a given process.
process_index – Returns the integer process index of this process.
device_count – Returns the total number of devices.
local_device_count – Returns the number of devices addressable by this process.
process_count – Returns the number of JAX processes associated with the backend.
jax.
jit
(fun, *, static_argnums=None, static_argnames=None, device=None, backend=None, donate_argnums=(), inline=False)[source]¶ Sets up
fun
for just-in-time compilation with XLA. Parameters
fun (~F) – Function to be jitted. Should be a pure function, as side-effects may only be executed once. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by
static_argnums
can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.static_argnums (
Union
[int
,Iterable
[int
],None
]) – An optional int or collection of ints that specify which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object.
Static arguments should be hashable, meaning both
__hash__
and__eq__
are implemented, and immutable. Calling the jitted function with different values for these constants will trigger recompilation. Arguments that are not arrays or containers thereof must be marked as static.If neither
static_argnums
norstatic_argnames
is provided, no arguments are treated as static. Ifstatic_argnums
is not provided butstatic_argnames
is, or vice versa, JAX usesinspect.signature(fun)
to find any positional arguments that correspond tostatic_argnames
(or vice versa). If bothstatic_argnums
andstatic_argnames
are provided,inspect.signature
is not used, and only actual parameters listed in eitherstatic_argnums
orstatic_argnames
will be treated as static.static_argnames (
Union
[str
,Iterable
[str
],None
]) – An optional string or collection of strings specifying which named arguments to treat as static (compile-time constant). See the comment on static_argnums
for details. If not provided butstatic_argnums
is set, the default is based on callinginspect.signature(fun)
to find corresponding named arguments.device (
Optional
[Device
]) – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved viajax.devices()
.) The default is inherited from XLA’s DeviceAssignment logic and is usually to usejax.devices()[0]
.backend (
Optional
[str
]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend:'cpu'
,'gpu'
, or'tpu'
.donate_argnums (
Union
[int
,Iterable
[int
]]) – Specify which arguments are “donated” to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation, JAX will raise an error if you try to. By default, no arguments are donated.inline (
bool
) – Specify whether this function should be inlined into enclosing jaxprs (rather than being represented as an application of the xla_call primitive with its own subjaxpr). Default False.
 Return type
~F
 Returns
A wrapped version of
fun
, set up for just-in-time compilation.
In the following example,
selu
can be compiled into a single fused kernel by XLA:>>> import jax >>> >>> @jax.jit ... def selu(x, alpha=1.67, lmbda=1.05): ... return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha) >>> >>> key = jax.random.PRNGKey(0) >>> x = jax.random.normal(key, (10,)) >>> print(selu(x)) [0.54485 0.27744 0.29255 0.91421 0.62452 0.24748 0.85743 0.78232 0.76827 0.59566 ]
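The static_argnums behavior described above can be illustrated with a small sketch (not from the original docstring; the function and argument names here are made up): marking a non-array argument as static allows Python control flow on it, at the cost of recompilation when its value changes:

>>> from functools import partial
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> @partial(jax.jit, static_argnums=(1,))
... def scale(x, use_double):
...   # `use_double` is a static (compile-time constant) argument, so ordinary
...   # Python control flow on it is allowed; a new value triggers recompilation.
...   return x * 2 if use_double else x
...
>>> print(scale(jnp.arange(3.), True))
[0. 2. 4.]
>>> print(scale(jnp.arange(3.), False))
[0. 1. 2.]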

jax.
disable_jit
()[source]¶ Context manager that disables
jit()
behavior under its dynamic context.For debugging it is useful to have a mechanism that disables
jit()
everywhere in a dynamic context.Values that have a data dependence on the arguments to a jitted function are traced and abstracted. For example, an abstract value may be a
ShapedArray
instance, representing the set of all possible arrays with a given shape and dtype, but not representing one concrete array with specific values. You might notice those if you use a benign side-effecting operation in a jitted function, like a print:>>> import jax >>> >>> @jax.jit ... def f(x): ... y = x * 2 ... print("Value of y is", y) ... return y + 3 ... >>> print(f(jax.numpy.array([1, 2, 3]))) Value of y is Traced<ShapedArray(int32[3])>with<DynamicJaxprTrace(level=0/1)> [5 7 9]
Here
y
has been abstracted byjit()
to aShapedArray
, which represents an array with a fixed shape and type but an arbitrary value. The value ofy
is also traced. If we want to see a concrete value while debugging, and avoid the tracer too, we can use thedisable_jit()
context manager:>>> import jax >>> >>> with jax.disable_jit(): ... print(f(jax.numpy.array([1, 2, 3]))) ... Value of y is [2 4 6] [5 7 9]

jax.
xla_computation
(fun, static_argnums=(), axis_env=None, in_parts=None, out_parts=None, backend=None, tuple_args=False, instantiate_const_outputs=None, return_shape=False, donate_argnums=())[source]¶ Creates a function that produces its XLA computation given example args.
 Parameters
fun (
Callable
) – Function from which to form XLA computations.static_argnums (
Union
[int
,Iterable
[int
]]) – See thejax.jit()
docstring.axis_env (
Optional
[Sequence
[Tuple
[Any
,int
]]]) – Optional, a sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications ofjax.pmap()
. See the examples below.in_parts – Optional, how each argument to
fun
should be partitioned or replicated. This is used to specify partitioned XLA computations, seesharded_jit
for more info.out_parts – Optional, how each output of
fun
should be partitioned or replicated. This is used to specify partitioned XLA computations, seesharded_jit
for more info.backend (
Optional
[str
]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend:'cpu'
,'gpu'
, or'tpu'
.tuple_args (
bool
) – Optional bool, defaults toFalse
. IfTrue
, the resulting XLA computation will have a single tuple argument that is unpacked into the specified function arguments. If None, tupling will be enabled when there are more than 100 arguments, since some platforms have limits on argument arity.instantiate_const_outputs (
Optional
[bool
]) – Deprecated argument, does nothing.return_shape (
bool
) – Optional boolean, defaults toFalse
. IfTrue
, the wrapped function returns a pair where the first element is the XLA computation and the second element is a pytree with the same structure as the output offun
and where the leaves are objects withshape
,dtype
, andnamed_shape
attributes representing the corresponding types of the output leaves.donate_argnums (
Union
[int
,Iterable
[int
]]) – Specify which arguments are “donated” to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation, JAX will raise an error if you try to.
 Return type
 Returns
A wrapped version of
fun
that when applied to example arguments returns a built XLA Computation (see xla_client.py), from which representations of the unoptimized XLA HLO computation can be extracted using methods likeas_hlo_text
,as_serialized_hlo_module_proto
, andas_hlo_dot_graph
. If the argumentreturn_shape
isTrue
, then the wrapped function returns a pair where the first element is the XLA Computation and the second element is a pytree representing the structure, shapes, dtypes, and named shapes of the output offun
.Concrete example arguments are not always necessary. For those arguments not indicated by
static_argnums
, any object withshape
anddtype
attributes is acceptable (excepting namedtuples, which are treated as Python containers).
For example:
>>> import jax >>> >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x)) >>> c = jax.xla_computation(f)(3.) >>> print(c.as_hlo_text()) HloModule xla_computation_f.6 ENTRY xla_computation_f.6 { constant.2 = pred[] constant(false) parameter.1 = f32[] parameter(0) cosine.3 = f32[] cosine(parameter.1) sine.4 = f32[] sine(cosine.3) ROOT tuple.5 = (f32[]) tuple(sine.4) }
Alternatively, the assignment to
c
above could be written:>>> import types >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32)) >>> c = jax.xla_computation(f)(scalar)
Here’s an example that involves a parallel collective and axis name:
>>> def f(x): return x - jax.lax.psum(x, 'i') >>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2) >>> print(c.as_hlo_text()) HloModule jaxpr_computation.9 primitive_computation.3 { parameter.4 = s32[] parameter(0) parameter.5 = s32[] parameter(1) ROOT add.6 = s32[] add(parameter.4, parameter.5) } ENTRY jaxpr_computation.9 { tuple.1 = () tuple() parameter.2 = s32[] parameter(0) allreduce.7 = s32[] allreduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3 ROOT subtract.8 = s32[] subtract(parameter.2, allreduce.7) }
Notice the
replica_groups
that were generated. Here’s an example that generates more interestingreplica_groups
:>>> from jax import lax >>> def g(x): ... rowsum = lax.psum(x, 'i') ... colsum = lax.psum(x, 'j') ... allsum = lax.psum(x, ('i', 'j')) ... return rowsum, colsum, allsum ... >>> axis_env = [('i', 4), ('j', 2)] >>> c = xla_computation(g, axis_env=axis_env)(5.) >>> print(c.as_hlo_text()) HloModule jaxpr_computation__1.19 [removed uninteresting text here] ENTRY jaxpr_computation__1.19 { tuple.1 = () tuple() parameter.2 = f32[] parameter(0) allreduce.7 = f32[] allreduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3 allreduce.12 = f32[] allreduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8 allreduce.17 = f32[] allreduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13 ROOT tuple.18 = (f32[], f32[], f32[]) tuple(allreduce.7, allreduce.12, allreduce.17) }

jax.
make_jaxpr
(fun, static_argnums=(), axis_env=None, return_shape=False)[source]¶ Creates a function that produces its jaxpr given example args.
 Parameters
fun (
Callable
) – The function whosejaxpr
is to be computed. Its positional arguments and return value should be arrays, scalars, or standard Python containers (tuple/list/dict) thereof.static_argnums (
Union
[int
,Iterable
[int
]]) – See thejax.jit()
docstring.axis_env (
Optional
[Sequence
[Tuple
[Any
,int
]]]) – Optional, a sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications ofjax.pmap()
.return_shape (
bool
) – Optional boolean, defaults toFalse
. IfTrue
, the wrapped function returns a pair where the first element is the XLA computation and the second element is a pytree with the same structure as the output offun
and where the leaves are objects withshape
,dtype
, andnamed_shape
attributes representing the corresponding types of the output leaves.
 Return type
Callable
[…,ClosedJaxpr
] Returns
A wrapped version of
fun
that when applied to example arguments returns aClosedJaxpr
representation offun
on those arguments. If the argumentreturn_shape
isTrue
, then the returned function instead returns a pair where the first element is theClosedJaxpr
representation offun
and the second element is a pytree representing the structure, shape, dtypes, and named shapes of the output offun
.
A
jaxpr
is JAX’s intermediate representation for program traces. Thejaxpr
language is based on the simply-typed first-order lambda calculus with let-bindings. make_jaxpr()
adapts a function to return itsjaxpr
, which we can inspect to understand what JAX is doing internally. Thejaxpr
returned is a trace offun
abstracted toShapedArray
level. Other levels of abstraction exist internally.We do not describe the semantics of the
jaxpr
language in detail here, but instead give a few examples.>>> import jax >>> >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x)) >>> print(f(3.0)) -0.83602 >>> jax.make_jaxpr(f)(3.0) { lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) } >>> jax.make_jaxpr(jax.grad(f))(3.0) { lambda ; a:f32[]. let b:f32[] = cos a c:f32[] = sin a _:f32[] = sin b d:f32[] = cos b e:f32[] = mul 1.0 d f:f32[] = neg e g:f32[] = mul f c in (g,) }

jax.
eval_shape
(fun, *args, **kwargs)[source]¶ Compute the shape/dtype of
fun
without any FLOPs.This utility function is useful for performing shape inference. Its input/output behavior is defined by:
def eval_shape(fun, *args, **kwargs):
  out = fun(*args, **kwargs)
  return jax.tree_util.tree_map(shape_dtype_struct, out)

def shape_dtype_struct(x):
  return ShapeDtypeStruct(x.shape, x.dtype)

class ShapeDtypeStruct:
  __slots__ = ["shape", "dtype"]
  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = dtype
In particular, the output is a pytree of objects that have
shape
anddtype
attributes, but nothing else about them is guaranteed by the API.But instead of applying
fun
directly, which might be expensive, it uses JAX’s abstract interpretation machinery to evaluate the shapes without doing any FLOPs.Using
eval_shape()
can also catch shape errors, and will raise the same shape errors as evaluating fun(*args, **kwargs)
. Parameters
fun (
Callable
) – The function whose output shape should be evaluated.*args – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types. Since only the
shape
anddtype
attributes are accessed, only values that duck-type arrays are required, rather than real ndarrays. The duck-typed objects cannot be namedtuples because those are treated as standard Python containers. See the example below.**kwargs – a keyword argument dict of arrays, scalars, or (nested) standard Python containers (pytrees) of those types. As in
args
, array values need only be duck-typed to have
anddtype
attributes.
For example:
>>> import jax >>> import jax.numpy as jnp >>> >>> f = lambda A, x: jnp.tanh(jnp.dot(A, x)) >>> class MyArgArray(object): ... def __init__(self, shape, dtype): ... self.shape = shape ... self.dtype = jnp.dtype(dtype) ... >>> A = MyArgArray((2000, 3000), jnp.float32) >>> x = MyArgArray((3000, 1000), jnp.float32) >>> out = jax.eval_shape(f, A, x) # no FLOPs performed >>> print(out.shape) (2000, 1000) >>> print(out.dtype) float32

jax.
device_put
(x, device=None)[source]¶ Transfers
x
todevice
. Parameters
If the
device
parameter isNone
, then this operation behaves like the identity function if the operand is on any device already, otherwise it transfers the data to the default device, uncommitted.For more details on data placement see the FAQ on data placement.
 Returns
A copy of
x
that resides ondevice
.
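A minimal usage sketch (assuming at least one device is available; committing to jax.devices()[0] is an illustrative choice, not a requirement):

>>> import jax
>>> x = jax.numpy.array([1., 2., 3.])
>>> y = jax.device_put(x, jax.devices()[0])  # commit x to a specific device
>>> print(y)
[1. 2. 3.]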

jax.
device_put_replicated
(x, devices)[source]¶ Transfer array(s) to each specified device and form ShardedDeviceArray(s).
 Parameters
 Returns
A ShardedDeviceArray or (nested) Python container thereof representing the value of
x
broadcasted along a new leading axis of sizelen(devices)
, with each slice along that new leading axis backed by memory on the device specified by the corresponding entry indevices
.
Examples
Passing an array:
>>> import jax >>> devices = jax.local_devices() >>> x = jax.numpy.array([1., 2., 3.]) >>> y = jax.device_put_replicated(x, devices) >>> np.allclose(y, jax.numpy.stack([x for _ in devices])) True
See also
device_put
device_put_sharded

jax.
device_put_sharded
(shards, devices)[source]¶ Transfer array shards to specified devices and form ShardedDeviceArray(s).
 Parameters
shards (
Sequence
[Any
]) – A sequence of arrays, scalars, or (nested) standard Python containers thereof representing the shards to be stacked together to form the output. The length ofshards
must equal the length ofdevices
.devices (
Sequence
[Device
]) – A sequence ofDevice
instances representing the devices to which corresponding shards inshards
will be transferred.
 Returns
A ShardedDeviceArray or (nested) Python container thereof representing the elements of
shards
stacked together, with each shard backed by physical device memory specified by the corresponding entry indevices
.
Examples
Passing a list of arrays for
shards
results in a sharded array containing a stacked version of the inputs:>>> import jax >>> devices = jax.local_devices() >>> x = [jax.numpy.ones(5) for device in devices] >>> y = jax.device_put_sharded(x, devices) >>> np.allclose(y, jax.numpy.stack(x)) True
Passing a list of nested container objects with arrays at the leaves for
shards
corresponds to stacking the shards at each leaf. This requires all entries in the list to have the same tree structure:>>> x = [(i, jax.numpy.arange(i, i + 4)) for i in range(len(devices))] >>> y = jax.device_put_sharded(x, devices) >>> type(y) <class 'tuple'> >>> y0 = jax.device_put_sharded([a for a, b in x], devices) >>> y1 = jax.device_put_sharded([b for a, b in x], devices) >>> np.allclose(y[0], y0) True >>> np.allclose(y[1], y1) True
See also
device_put
device_put_replicated

jax.
device_get
(x)[source]¶ Transfer
x
to host. Parameters
x (
Any
) – An array, scalar, DeviceArray or (nested) standard Python container thereof representing the array to be transferred to host. Returns
An array or (nested) Python container thereof representing the value of
x
.
Examples
Passing a DeviceArray:
>>> import jax >>> x = jax.numpy.array([1., 2., 3.]) >>> jax.device_get(x) array([1., 2., 3.], dtype=float32)
Passing a scalar (has no effect):
>>> jax.device_get(1) 1
See also
device_put
device_put_sharded
device_put_replicated

jax.
named_call
(fun, *, name=None)[source]¶ Adds a user-specified name to a function when staging out JAX computations.
When staging out computations for just-in-time compilation to XLA (or other backends such as TensorFlow) JAX runs your Python program but by default does not preserve any of the function names or other metadata associated with it. This can make debugging the staged out (and/or compiled) representation of your program complicated because there is limited context information for each operation being executed.
named_call tells JAX to stage the given function out as a subcomputation with a specific name. When the staged out program is compiled with XLA these named subcomputations are preserved and show up in debugging utilities like the TensorFlow Profiler in TensorBoard. Names are also preserved when staging out JAX programs to TensorFlow using
experimental.jax2tf.convert()
. Parameters
 Return type
 Returns
A version of fun that is wrapped in a name_scope.
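A short hedged sketch of typical usage (the wrapped function and the chosen name are illustrative):

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> @jax.jit
... def layer(x):
...   # tanh will appear as a sub-computation named "activation" in the
...   # staged-out (and compiled) representation.
...   activation = jax.named_call(jnp.tanh, name="activation")
...   return activation(x) * 2.0
...
>>> print(layer(jnp.array([0.0])))
[0.]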

jax.
grad
(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[source]¶ Creates a function that evaluates the gradient of
fun
. Parameters
fun (
Callable
) – Function to be differentiated. Its arguments at positions specified byargnums
should be arrays, scalars, or standard Python containers. Argument arrays in the positions specified byargnums
must be of inexact (i.e., floatingpoint or complex) type. It should return a scalar (which includes arrays with shape()
but not arrays with shape(1,)
etc.)argnums (
Union
[int
,Sequence
[int
]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).has_aux (
bool
) – Optional, bool. Indicates whetherfun
returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.holomorphic (
bool
) – Optional, bool. Indicates whetherfun
is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.allow_int (
bool
) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.reduce_axes (
Sequence
[Any
]) – Optional, tuple of axis names. If an axis is listed here, andfun
implicitly broadcasts a value over that axis, the backward pass will perform apsum
of the corresponding gradient. Otherwise, the gradient will be per-example over named axes. For example, if 'batch'
is a named batch axis,grad(f, reduce_axes=('batch',))
will create a function that computes the total gradient whilegrad(f)
will create one that computes the per-example gradient.
 Return type
 Returns
A function with the same arguments as
fun
, that evaluates the gradient offun
. Ifargnums
is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. Ifhas_aux
is True then a pair of (gradient, auxiliary_data) is returned.
For example:
>>> import jax >>> >>> grad_tanh = jax.grad(jax.numpy.tanh) >>> print(grad_tanh(0.2)) 0.961043
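The argnums and has_aux options described above can be combined; here is a hedged sketch (the loss function is made up for illustration):

>>> import jax
>>>
>>> def loss(w, b):
...   pred = w * 3.0 + b
...   err = pred - 10.0
...   return err ** 2, pred   # (scalar to differentiate, auxiliary data)
...
>>> (dw, db), pred = jax.grad(loss, argnums=(0, 1), has_aux=True)(2.0, 1.0)
>>> print(dw, db, pred)
-18.0 -6.0 7.0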

jax.
value_and_grad
(fun, argnums=0, has_aux=False, holomorphic=False, allow_int=False, reduce_axes=())[source]¶ Create a function that evaluates both
fun
and the gradient offun
. Parameters
fun (
Callable
) – Function to be differentiated. Its arguments at positions specified byargnums
should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape()
but not arrays with shape(1,)
etc.)argnums (
Union
[int
,Sequence
[int
]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).has_aux (
bool
) – Optional, bool. Indicates whetherfun
returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.holomorphic (
bool
) – Optional, bool. Indicates whetherfun
is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.allow_int (
bool
) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.reduce_axes (
Sequence
[Any
]) – Optional, tuple of axis names. If an axis is listed here, andfun
implicitly broadcasts a value over that axis, the backward pass will perform apsum
of the corresponding gradient. Otherwise, the gradient will be per-example over named axes. For example, if 'batch'
is a named batch axis,value_and_grad(f, reduce_axes=('batch',))
will create a function that computes the total gradient whilevalue_and_grad(f)
will create one that computes the per-example gradient.
 Return type
 Returns
A function with the same arguments as
fun
that evaluates bothfun
and the gradient offun
and returns them as a pair (a two-element tuple). If argnums
is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a sequence of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments.
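For example (a minimal sketch; compare with grad() above):

>>> import jax
>>>
>>> f = lambda x: x ** 3
>>> value, grad = jax.value_and_grad(f)(2.0)
>>> print(value)
8.0
>>> print(grad)
12.0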

jax.
jacfwd
(fun, argnums=0, holomorphic=False)[source]¶ Jacobian of
fun
evaluated column-by-column using forward-mode AD. Parameters
fun (
Callable
) – Function whose Jacobian is to be computed.argnums (
Union
[int
,Sequence
[int
]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default0
).holomorphic (
bool
) – Optional, bool. Indicates whetherfun
is promised to be holomorphic. Default False.
 Return type
 Returns
A function with the same arguments as
fun
, that evaluates the Jacobian offun
using forward-mode automatic differentiation.
>>> import jax >>> import jax.numpy as jnp >>> >>> def f(x): ... return jnp.asarray( ... [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])]) ... >>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.]))) [[ 1. 0. 0. ] [ 0. 0. 5. ] [ 0. 16. -2. ] [ 1.6209 0. 0.84147]]

jax.
jacrev
(fun, argnums=0, holomorphic=False, allow_int=False)[source]¶ Jacobian of
fun
evaluated row-by-row using reverse-mode AD. Parameters
fun (
Callable
) – Function whose Jacobian is to be computed.argnums (
Union
[int
,Sequence
[int
]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default0
).holomorphic (
bool
) – Optional, bool. Indicates whetherfun
is promised to be holomorphic. Default False.allow_int (
bool
) – Optional, bool. Whether to allow differentiating with respect to integer valued inputs. The gradient of an integer input will have a trivial vector-space dtype (float0). Default False.
 Return type
 Returns
A function with the same arguments as
fun
, that evaluates the Jacobian offun
using reverse-mode automatic differentiation.
>>> import jax >>> import jax.numpy as jnp >>> >>> def f(x): ... return jnp.asarray( ... [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])]) ... >>> print(jax.jacrev(f)(jnp.array([1., 2., 3.]))) [[ 1. 0. 0. ] [ 0. 0. 5. ] [ 0. 16. -2. ] [ 1.6209 0. 0.84147]]

jax.
hessian
(fun, argnums=0, holomorphic=False)[source]¶ Hessian of
fun
as a dense array. Parameters
fun (
Callable
) – Function whose Hessian is to be computed. Its arguments at positions specified byargnums
should be arrays, scalars, or standard Python containers thereof. It should return arrays, scalars, or standard Python containers thereof.argnums (
Union
[int
,Sequence
[int
]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default0
).holomorphic (
bool
) – Optional, bool. Indicates whetherfun
is promised to be holomorphic. Default False.
 Return type
 Returns
A function with the same arguments as
fun
, that evaluates the Hessian offun
.
>>> import jax >>> >>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6 >>> print(jax.hessian(g)(jax.numpy.array([1., 2.]))) [[ 6. -2.] [ -2. -480.]]
hessian()
is a generalization of the usual definition of the Hessian that supports nested Python containers (i.e. pytrees) as inputs and outputs. The tree structure ofjax.hessian(fun)(x)
is given by forming a tree product of the structure offun(x)
with a tree product of two copies of the structure ofx
. A tree product of two tree structures is formed by replacing each leaf of the first tree with a copy of the second. For example:>>> import jax.numpy as jnp >>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])} >>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.})) {'c': {'a': {'a': DeviceArray([[[ 2., 0.], [ 0., 0.]], [[ 0., 0.], [ 0., 12.]]], dtype=float32), 'b': DeviceArray([[[ 1. , 0. ], [ 0. , 0. ]], [[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32)}, 'b': {'a': DeviceArray([[[ 1. , 0. ], [ 0. , 0. ]], [[ 0. , 0. ], [ 0. , 12.317766]]], dtype=float32), 'b': DeviceArray([[[0. , 0. ], [0. , 0. ]], [[0. , 0. ], [0. , 3.843624]]], dtype=float32)}}}
Thus each leaf in the tree structure of
jax.hessian(fun)(x)
corresponds to a leaf offun(x)
and a pair of leaves ofx
. For each leaf injax.hessian(fun)(x)
, if the corresponding array leaf offun(x)
has shape(out_1, out_2, ...)
and the corresponding array leaves ofx
have shape(in_1_1, in_1_2, ...)
and(in_2_1, in_2_2, ...)
respectively, then the Hessian leaf has shape(out_1, out_2, ..., in_1_1, in_1_2, ..., in_2_1, in_2_2, ...)
. In other words, the Python tree structure represents the block structure of the Hessian, with blocks determined by the input and output pytrees.In particular, an array is produced (with no pytrees involved) when the function input
x
and outputfun(x)
are each a single array, as in theg
example above. Iffun(x)
has shape(out1, out2, ...)
andx
has shape(in1, in2, ...)
thenjax.hessian(fun)(x)
has shape(out1, out2, ..., in1, in2, ..., in1, in2, ...)
. To flatten pytrees into 1D vectors, consider using jax.flatten_util.ravel_pytree()
.

jax.
jvp
(fun, primals, tangents)[source]¶ Computes a (forward-mode) Jacobian-vector product of
fun
. Parameters
fun (
Callable
) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.primals – The primal values at which the Jacobian of
fun
should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters offun
.tangents – The tangent vector for which the Jacobianvector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as
primals
.
 Return type
 Returns
A
(primals_out, tangents_out)
pair, whereprimals_out
isfun(*primals)
, andtangents_out
is the Jacobianvector product offunction
evaluated atprimals
withtangents
. Thetangents_out
value has the same Python tree structure and shapes asprimals_out
.
For example:
>>> import jax >>> >>> y, v = jax.jvp(jax.numpy.sin, (0.1,), (0.2,)) >>> print(y) 0.09983342 >>> print(v) 0.19900084

jax.
linearize
(fun, *primals)[source]¶ Produces a linear approximation to
fun
usingjvp()
and partial eval. Parameters
fun (
Callable
) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard python container of arrays or scalars.primals – The primal values at which the Jacobian of
fun
should be evaluated. Should be a tuple of arrays, scalar, or standard Python container thereof. The length of the tuple is equal to the number of positional parameters offun
.
 Return type
 Returns
A pair where the first element is the value of
f(*primals)
and the second element is a function that evaluates the (forward-mode) Jacobian-vector product of fun
evaluated atprimals
without redoing the linearization work.
In terms of values computed,
linearize()
behaves much like a curried jvp()
, where these two code blocks compute the same values:

y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

y, f_jvp = jax.linearize(f, x)
out_tangent = f_jvp(in_tangent)
However, the difference is that
linearize()
uses partial evaluation so that the functionf
is not re-linearized on calls to f_jvp
. In general that means the memory usage scales with the size of the computation, much like in reverse-mode. (Indeed, linearize()
has a similar signature tovjp()
!)This function is mainly useful if you want to apply
f_jvp
multiple times, i.e. to evaluate a pushforward for many different input tangent vectors at the same linearization point. Moreover if all the input tangent vectors are known at once, it can be more efficient to vectorize usingvmap()
, as in:

pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
By using
vmap()
andjvp()
together like this we avoid the stored-linearization memory cost that scales with the depth of the computation, which is incurred by both linearize()
andvjp()
.Here’s a more complete example of using
linearize()
:>>> import jax >>> import jax.numpy as jnp >>> >>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.) ... >>> jax.jvp(f, (2.,), (3.,)) (DeviceArray(3.26819, dtype=float32, weak_type=True), DeviceArray(-5.00753, dtype=float32, weak_type=True)) >>> y, f_jvp = jax.linearize(f, 2.) >>> print(y) 3.2681944 >>> print(f_jvp(3.)) -5.007528 >>> print(f_jvp(4.)) -6.676704

jax.
linear_transpose
(fun, *primals, reduce_axes=())[source]¶ Transpose a function that is promised to be linear.
For linear functions, this transformation is equivalent to
vjp
, but avoids the overhead of computing the forward pass.The outputs of the transposed function will always have the exact same dtypes as
primals
, even if some values are truncated (e.g., from complex to float, or from float64 to float32). To avoid truncation, use dtypes inprimals
that match the full range of desired outputs from the transposed function. Integer dtypes are not supported. Parameters
fun (
Callable
) – the linear function to be transposed.*primals – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e., pytrees) of those types used for evaluating the shape/dtype of
fun(*primals)
. These arguments may be real scalars/ndarrays, but that is not required: only theshape
anddtype
attributes are accessed. See below for an example. (Note that the ducktyped objects cannot be namedtuples because those are treated as standard Python containers.)reduce_axes – Optional, tuple of axis names. If an axis is listed here, and
fun
implicitly broadcasts a value over that axis, the backward pass will perform apsum
of the corresponding cotangent. Otherwise, the transposed function will be per-example over named axes. For example, if 'batch'
is a named batch axis,linear_transpose(f, *args, reduce_axes=('batch',))
will create a transpose function that sums over the batch while linear_transpose(f, *args)
will create a per-example transpose.
 Return type
 Returns
A callable that calculates the transpose of
fun
. Valid input into this function must have the same shape/dtypes/structure as the result offun(*primals)
. Output will be a tuple, with the same shape/dtypes/structure asprimals
.
>>> import jax >>> import types >>> >>> f = lambda x, y: 0.5 * x - 0.5 * y >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32)) >>> f_transpose = jax.linear_transpose(f, scalar, scalar) >>> f_transpose(1.0) (DeviceArray(0.5, dtype=float32), DeviceArray(-0.5, dtype=float32))

jax.
vjp
(fun: Callable[[…], T], *primals: Any, has_aux: Literal[False] = 'False', reduce_axes: Sequence[Any] = '()') → Tuple[T, Callable][source]¶ 
jax.
vjp
(fun: Callable[[…], Tuple[T, U]], *primals: Any, has_aux: Literal[True], reduce_axes: Sequence[Any] = '()') → Tuple[T, Callable, U] 
jax.
vjp
(fun: Callable[[…], T], *primals: Any) → Tuple[T, Callable] 
jax.
vjp
(fun: Callable[[…], Any], *primals: Any, has_aux: bool, reduce_axes: Sequence[Any] = '()') → Union[Tuple[Any, Callable], Tuple[Any, Callable, Any]] Compute a (reverse-mode) vector-Jacobian product of
fun
.grad()
is implemented as a special case ofvjp()
. Parameters
fun (
Callable
) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.primals – A sequence of primal values at which the Jacobian of
fun
should be evaluated. The length ofprimals
should be equal to the number of positional parameters tofun
. Each primal value should be a tuple of arrays, scalar, or standard Python containers thereof.has_aux (
bool
) – Optional, bool. Indicates whetherfun
returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.reduce_axes – Optional, tuple of axis names. If an axis is listed here, and
fun
implicitly broadcasts a value over that axis, the backward pass will perform apsum
of the corresponding gradient. Otherwise, the VJP will be per-example over named axes. For example, if 'batch'
is a named batch axis,vjp(f, *args, reduce_axes=('batch',))
will create a VJP function that sums over the batch whilevjp(f, *args)
will create a per-example VJP.
 Return type
 Returns
If
has_aux
isFalse
, returns a(primals_out, vjpfun)
pair, whereprimals_out
isfun(*primals)
.vjpfun
is a function from a cotangent vector with the same shape asprimals_out
to a tuple of cotangent vectors with the same shape asprimals
, representing the vectorJacobian product offun
evaluated atprimals
. Ifhas_aux
isTrue
, returns a(primals_out, vjpfun, aux)
tuple whereaux
is the auxiliary data returned byfun
.
>>> import jax >>> >>> def f(x, y): ... return jax.numpy.sin(x), jax.numpy.cos(y) ... >>> primals, f_vjp = jax.vjp(f, 0.5, 1.0) >>> xbar, ybar = f_vjp((-0.7, 0.3)) >>> print(xbar) -0.61430776 >>> print(ybar) -0.2524413

class
jax.
custom_jvp
(fun, nondiff_argnums=())[source]¶ Set up a JAX-transformable function for a custom JVP rule definition.
This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a differentiation transformation (like
jax.jvp()
orjax.grad()
) is applied, in which case a custom user-supplied JVP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation. There are two instance methods available for defining the custom JVP rule:
defjvp()
for defining a single custom JVP rule for all the function’s inputs, and for conveniencedefjvps()
, which wrapsdefjvp()
, and allows you to provide separate definitions for the partial derivatives of the function w.r.t. each of its arguments.For example:
@jax.custom_jvp
def f(x, y):
  return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out
For a more detailed introduction, see the tutorial.

defjvp
(jvp)[source]¶ Define a custom JVP rule for the function represented by this instance.
 Parameters
jvp (
Callable
[…,Tuple
[~ReturnValue, ~ReturnValue]]) – a Python callable representing the custom JVP rule. When there are nonondiff_argnums
, thejvp
function should accept two arguments, where the first is a tuple of primal inputs and the second is a tuple of tangent inputs. The lengths of both tuples are equal to the number of parameters of thecustom_jvp
function. Thejvp
function should produce as output a pair where the first element is the primal output and the second element is the tangent output. Elements of the input and output tuples may be arrays or any nested tuples/lists/dicts thereof. Return type
 Returns
None.
Example:
@jax.custom_jvp
def f(x, y):
  return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out

defjvps
(*jvps)[source]¶ Convenience wrapper for defining JVPs for each argument separately.
This convenience wrapper cannot be used together with
nondiff_argnums
. Parameters
*jvps – a sequence of functions, one for each positional argument of the
custom_jvp
function. Each function takes as arguments the tangent value for the corresponding primal input, the primal output, and the primal inputs. See the example below. Returns
None.
Example:
@jax.custom_jvp
def f(x, y):
  return jnp.sin(x) * y

f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
          lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)


class
jax.
custom_vjp
(fun, nondiff_argnums=())[source]¶ Set up a JAX-transformable function for a custom VJP rule definition.
This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a reverse-mode differentiation transformation (like
jax.grad()
) is applied, in which case a custom user-supplied VJP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation. There is a single instance method, defvjp()
, which may be used to define the custom VJP rule. This decorator precludes the use of forward-mode automatic differentiation.
For example:
@jax.custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)
For a more detailed introduction, see the tutorial.

defvjp
(fwd, bwd)[source]¶ Define a custom VJP rule for the function represented by this instance.
 Parameters
fwd (
Callable
[…,Tuple
[~ReturnValue,Any
]]) – a Python callable representing the forward pass of the custom VJP rule. When there are nonondiff_argnums
, thefwd
function has the same input signature as the underlying primal function. It should return as output a pair, where the first element represents the primal output and the second element represents any “residual” values to store from the forward pass for use on the backward pass by the functionbwd
. Input arguments and elements of the output pair may be arrays or nested tuples/lists/dicts thereof.bwd (
Callable
[…,Tuple
[Any
, …]]) – a Python callable representing the backward pass of the custom VJP rule. When there are nonondiff_argnums
, thebwd
function takes two arguments, where the first is the “residual” values produced on the forward pass byfwd
, and the second is the output cotangent with the same structure as the primal function output. The output ofbwd
must be a tuple of length equal to the number of arguments of the primal function, and the tuple elements may be arrays or nested tuples/lists/dicts thereof so as to match the structure of the primal input arguments.
 Return type
 Returns
None.
Example:
@jax.custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)


jax.
closure_convert
(fun, *example_args)[source]¶ Closure conversion utility, for use with higher-order custom derivatives.
To define custom derivatives such as with
jax.custom_vjp(f)
, the target functionf
must take, as formal arguments, all values involved in differentiation. Iff
is a higher-order function, in that it accepts as an argument a Python function g
, then values stored away ing
’s closure will not be visible to the custom derivative rules, and attempts at AD involving these values will fail. One way around this is to convert the closure by extracting these values, and to pass them as explicit formal arguments across the custom derivative boundary. This utility carries out that conversion. More precisely, it closure-converts the function fun
specialized to the types of the arguments given inexample_args
.When we refer here to “values in the closure” of
fun
, we do not mean the values that are captured by Python directly whenfun
is defined (e.g. the Python objects infun.__closure__
, if the attribute exists). Rather, we mean values encountered during the execution offun
onexample_args
that determine its output. This may include, for instance, arrays captured transitively in Python closures, i.e. in the Python closure of functions called byfun
, the closures of the functions that they call, and so forth.The function
fun
must be a pure function.Example usage:
def minimize(objective_fn, x0):
  converted_fn, aux_args = closure_convert(objective_fn, x0)
  return _minimize(converted_fn, x0, *aux_args)

@partial(custom_vjp, nondiff_argnums=(0,))
def _minimize(objective_fn, x0, *args):
  z = objective_fn(x0, *args)
  # ... find minimizer x_opt ...
  return x_opt

def fwd(objective_fn, x0, *args):
  y = _minimize(objective_fn, x0, *args)
  return y, (y, args)

def rev(objective_fn, res, g):
  y, args = res
  y_bar = g
  # ... custom reverse-mode AD ...
  return x0_bar, *args_bars

_minimize.defvjp(fwd, rev)
 Parameters
fun – Python callable to be converted. Must be a pure function.
example_args – Arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e., pytrees) thereof, used to determine the types of the formal arguments to
fun
. This typespecialized form offun
is the function that will be closure converted.
 Returns
A pair comprising (i) a Python callable, accepting the same arguments as
fun
followed by arguments corresponding to the values hoisted from its closure, and (ii) a list of values hoisted from the closure.

jax.
checkpoint
(fun, concrete=False, prevent_cse=True, policy=None)[source]¶ Make
fun
recompute internal linearization points when differentiated.The
jax.checkpoint()
decorator, aliased tojax.remat
, provides a way to trade off computation time and memory cost in the context of automatic differentiation, especially with reverse-mode autodiff like jax.grad()
andjax.vjp()
but also withjax.linearize()
.When differentiating a function in reverse-mode, by default all the linearization points (e.g. inputs to elementwise nonlinear primitive operations) are stored when evaluating the forward pass so that they can be reused on the backward pass. This evaluation strategy can lead to a high memory cost, or even to poor performance on hardware accelerators where memory access is much more expensive than FLOPs.
An alternative evaluation strategy is for some of the linearization points to be recomputed (i.e. rematerialized) rather than stored. This approach can reduce memory usage at the cost of increased computation.
This function decorator produces a new version of
fun
which follows the rematerialization strategy rather than the default storeeverything strategy. That is, it returns a new version offun
which, when differentiated, doesn’t store any of its intermediate linearization points. Instead, these linearization points are recomputed from the function’s saved inputs.See the examples below.
 Parameters
fun (
Callable
) – Function for which the autodiff evaluation strategy is to be changed from the default of storing all intermediate linearization points to recomputing them. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof.concrete (
bool
) – Optional, boolean indicating whetherfun
may involve value-dependent Python control flow (default False). Support for such control flow is optional, and disabled by default, because in some edge-case compositions with jax.jit()
it can lead to some extra computation.prevent_cse (
bool
) – Optional, boolean indicating whether to prevent common subexpression elimination (CSE) optimizations in the HLO generated from differentiation. This CSE prevention has costs because it can foil other optimizations, and because it can incur high overheads on some backends, especially GPU. The default is True because otherwise, under ajit
orpmap
, CSE can defeat the purpose of this decorator. But in some settings, like when used inside ascan
, this CSE prevention mechanism is unnecessary, in which caseprevent_cse
can be set to False.policy (
Optional
[Callable
[…,bool
]]) – This is an experimental feature and the API is likely to change. Optional callable, one of the attributes ofjax.checkpoint_policies
, which takes as input a type-level specification of a first-order primitive application and returns a boolean indicating whether the corresponding output value(s) can be saved as a residual (or, if not, instead must be recomputed in the (co)tangent computation).
 Return type
 Returns
A function (callable) with the same input/output behavior as
fun
but which, when differentiated using e.g.jax.grad()
,jax.vjp()
, orjax.linearize()
, recomputes rather than stores intermediate linearization points, thus potentially saving memory at the cost of extra computation.
Here is a simple example:
>>> import jax >>> import jax.numpy as jnp
>>> @jax.checkpoint ... def g(x): ... y = jnp.sin(x) ... z = jnp.sin(y) ... return z ... >>> jax.value_and_grad(g)(2.0) (DeviceArray(0.78907233, dtype=float32, weak_type=True), DeviceArray(-0.2556391, dtype=float32))
Here, the same value is produced whether or not the
jax.checkpoint()
decorator is present. When the decorator is not present, the valuesjnp.cos(2.0)
andjnp.cos(jnp.sin(2.0))
are computed on the forward pass and are stored for use in the backward pass, because they are needed on the backward pass and depend only on the primal inputs. When usingjax.checkpoint()
, the forward pass will compute only the primal outputs and only the primal inputs (2.0
) will be stored for the backward pass. At that time, the valuejnp.sin(2.0)
is recomputed, along with the valuesjnp.cos(2.0)
andjnp.cos(jnp.sin(2.0))
.While
jax.checkpoint
controls what values are stored from the forward pass to be used on the backward pass, the total amount of memory required to evaluate a function or its VJP depends on many additional internal details of that function. Those details include which numerical primitives are used, how they’re composed, where jit and control flow primitives like scan are used, and other factors. The
jax.checkpoint()
decorator can be applied recursively to express sophisticated autodiff rematerialization strategies. For example:>>> def recursive_checkpoint(funs): ... if len(funs) == 1: ... return funs[0] ... elif len(funs) == 2: ... f1, f2 = funs ... return lambda x: f1(f2(x)) ... else: ... f1 = recursive_checkpoint(funs[:len(funs)//2]) ... f2 = recursive_checkpoint(funs[len(funs)//2:]) ... return lambda x: f1(jax.checkpoint(f2)(x)) ...
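The policy argument is experimental; as a hedged sketch (assuming the checkpoint_dots attribute of jax.checkpoint_policies is available in this version), one can ask for only the results of dot-like operations to be saved as residuals, with everything else recomputed. The differentiated values themselves are unchanged:

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def h(x):
...   return jnp.sin(jnp.sin(x))
...
>>> h_remat = jax.checkpoint(h, policy=jax.checkpoint_policies.checkpoint_dots)
>>> print(jax.grad(h_remat)(2.0) == jax.grad(h)(2.0))
True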

jax.
vmap
(fun, in_axes=0, out_axes=0, axis_name=None)[source]¶ Vectorizing map. Creates a function which maps
fun
over argument axes. Parameters
fun (~F) – Function to be mapped over additional axes.
in_axes –
An integer, None, or (nested) standard Python container (tuple/list/dict) thereof specifying which input array axes to map over.
If each positional argument to
fun
is an array, thenin_axes
can be an integer, a None, or a tuple of integers and Nones with length equal to the number of positional arguments tofun
. An integer orNone
indicates which array axis to map over for all arguments (withNone
indicating not to map any axis), and a tuple indicates which axis to map for each corresponding positional argument. Axis integers must be in the range [-ndim, ndim)
for each array, wherendim
is the number of dimensions (axes) of the corresponding input array.If the positional arguments to
fun
are container types, the corresponding element ofin_axes
can itself be a matching container, so that distinct array axes can be mapped for different container elements.in_axes
must be a container tree prefix of the positional argument tuple passed tofun
.At least one positional argument must have
in_axes
not None. The sizes of the mapped input axes for all mapped positional arguments must all be equal.Arguments passed as keywords are always mapped over their leading axis (i.e. axis index 0).
See below for examples.
out_axes – An integer, None, or (nested) standard Python container (tuple/list/dict) thereof indicating where the mapped axis should appear in the output. All outputs with a mapped axis must have a non-None
out_axes
specification. Axis integers must be in the range [-ndim, ndim)
for each output array, wherendim
is the number of dimensions (axes) of the array returned by thevmap()
ed function, which is one more than the number of dimensions (axes) of the corresponding array returned byfun
.axis_name – Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied.
 Return type
~F
 Returns
Batched/vectorized version of
fun
with arguments that correspond to those offun
, but with extra array axes at positions indicated byin_axes
, and a return value that corresponds to that offun
, but with extra array axes at positions indicated byout_axes
.
For example, we can implement a matrixmatrix product using a vector dot product:
>>> import jax.numpy as jnp >>> >>> vv = lambda x, y: jnp.vdot(x, y) # ([a], [a]) -> [] >>> mv = vmap(vv, (0, None), 0) # ([b,a], [a]) -> [b] (b is the mapped axis) >>> mm = vmap(mv, (None, 1), 1) # ([b,a], [a,c]) -> [b,c] (c is the mapped axis)
Here we use
[a,b]
to indicate an array with shape (a,b). Here are some variants:>>> mv1 = vmap(vv, (0, 0), 0) # ([b,a], [b,a]) -> [b] (b is the mapped axis) >>> mv2 = vmap(vv, (0, 1), 0) # ([b,a], [a,b]) -> [b] (b is the mapped axis) >>> mm2 = vmap(mv2, (1, 1), 0) # ([b,c,a], [a,c,b]) -> [c,b] (c is the mapped axis)
Here’s an example of using container types in
in_axes
to specify which axes of the container elements to map over:>>> A, B, C, D = 2, 3, 4, 5 >>> x = jnp.ones((A, B)) >>> y = jnp.ones((B, C)) >>> z = jnp.ones((C, D)) >>> def foo(tree_arg): ... x, (y, z) = tree_arg ... return jnp.dot(x, jnp.dot(y, z)) >>> tree = (x, (y, z)) >>> print(foo(tree)) [[12. 12. 12. 12. 12.] [12. 12. 12. 12. 12.]] >>> from jax import vmap >>> K = 6 # batch size >>> x = jnp.ones((K, A, B)) # batch axis in different locations >>> y = jnp.ones((B, K, C)) >>> z = jnp.ones((C, D, K)) >>> tree = (x, (y, z)) >>> vfoo = vmap(foo, in_axes=((0, (1, 2)),)) >>> print(vfoo(tree).shape) (6, 2, 5)
Here’s another example using container types in
in_axes
, this time a dictionary, to specify the elements of the container to map over:>>> dct = {'a': 0., 'b': jnp.arange(5.)} >>> x = 1. >>> def foo(dct, x): ... return dct['a'] + dct['b'] + x >>> out = vmap(foo, in_axes=({'a': None, 'b': 0}, None))(dct, x) >>> print(out) [1. 2. 3. 4. 5.]
The results of a vectorized function can be mapped or unmapped. For example, the function below returns a pair with the first element mapped and the second unmapped. Only for unmapped results we can specify
out_axes
to beNone
(to keep it unmapped).>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.)) (DeviceArray([4., 5.], dtype=float32), 8.0)
If the
out_axes
is specified for an unmapped result, the result is broadcast across the mapped axis:>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.)) (DeviceArray([4., 5.], dtype=float32), DeviceArray([8., 8.], dtype=float32, weak_type=True))
If the
out_axes
is specified for a mapped result, the result is transposed accordingly.Finally, here’s an example using
axis_name
together with collectives:>>> xs = jnp.arange(3. * 4.).reshape(3, 4) >>> print(vmap(lambda x: lax.psum(x, 'i'), axis_name='i')(xs)) [[12. 15. 18. 21.] [12. 15. 18. 21.] [12. 15. 18. 21.]]
See the
jax.pmap()
docstring for more examples involving collectives.

jax.numpy.
vectorize
(pyfunc, *, excluded=frozenset({}), signature=None)[source] Define a vectorized function with broadcasting.
vectorize()
is a convenience wrapper for defining vectorized functions with broadcasting, in the style of NumPy’s generalized universal functions. It allows for defining functions that are automatically repeated across any leading dimensions, without the implementation of the function needing to be concerned about how to handle higher dimensional inputs.jax.numpy.vectorize()
has the same interface asnumpy.vectorize
, but it is syntactic sugar for an auto-batching transformation (vmap()
) rather than a Python loop. This should be considerably more efficient, but the implementation must be written in terms of functions that act on JAX arrays. Parameters
pyfunc – function to vectorize.

excluded – optional set of integers representing positional arguments for which the function will not be vectorized. These will be passed directly to pyfunc unmodified (see the sketch at the end of this entry).

signature – optional generalized universal function signature, e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication. If provided, pyfunc will be called with (and expected to return) arrays with shapes given by the size of corresponding core dimensions. By default, pyfunc is assumed to take scalar arrays as input and output.
 Returns
Vectorized version of the given function.
Here are a few examples of how one could write vectorized linear algebra routines using vectorize():

>>> from functools import partial

>>> @partial(jnp.vectorize, signature='(k),(k)->(k)')
... def cross_product(a, b):
...   assert a.shape == b.shape and a.ndim == b.ndim == 1
...   return jnp.array([a[1] * b[2] - a[2] * b[1],
...                     a[2] * b[0] - a[0] * b[2],
...                     a[0] * b[1] - a[1] * b[0]])

>>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)')
... def matrix_vector_product(matrix, vector):
...   assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape
...   return matrix @ vector
These functions are only written to handle 1D or 2D arrays (the assert statements will never be violated), but with vectorize they support arbitrary dimensional inputs with NumPy style broadcasting, e.g.,

>>> cross_product(jnp.ones(3), jnp.ones(3)).shape
(3,)
>>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2, 3)
>>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape
(2, 2, 3)
>>> matrix_vector_product(jnp.ones(3), jnp.ones(3))
Traceback (most recent call last):
ValueError: input with shape (3,) does not have enough dimensions for all core dimensions ('n', 'k') on vectorized function with excluded=frozenset() and signature='(n,k),(k)->(k)'
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2,)
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape
(4, 2)
Note that this has different semantics than jnp.matmul:
>>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3)))
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4].
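None of the examples above exercises the excluded parameter; the following is a minimal sketch (not from the original docstring, reusing jnp and partial from above, with the hypothetical helper name scale) in which argument 1 is passed through unvectorized:

>>> @partial(jnp.vectorize, excluded={1}, signature='(k)->(k)')
... def scale(x, factor):
...   return x * factor            # factor is passed through as-is, not broadcast
>>> scale(jnp.ones((2, 3)), 2.).shape
(2, 3)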

jax.pmap(fun, axis_name=None, *, in_axes=0, out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None)[source]¶ Parallel map with support for collective operations.
The purpose of pmap() is to express single-program multiple-data (SPMD) programs. Applying pmap() to a function will compile the function with XLA (similarly to jit()), then execute it in parallel on XLA devices, such as multiple GPUs or multiple TPU cores. Semantically it is comparable to vmap() because both transformations map a function over array axes, but where vmap() vectorizes functions by pushing the mapped axis down into primitive operations, pmap() instead replicates the function and executes each replica on its own XLA device in parallel.

The mapped axis size must be less than or equal to the number of local XLA devices available, as returned by jax.local_device_count() (unless devices is specified, see below). For nested pmap() calls, the product of the mapped axis sizes must be less than or equal to the number of XLA devices.

Multi-process platforms: On multi-process platforms such as TPU pods, pmap() is designed to be used in SPMD Python programs, where every process is running the same Python code such that all processes run the same pmapped function in the same order. Each process should still call the pmapped function with mapped axis size equal to the number of local devices (unless devices is specified, see below), and an array of the same leading axis size will be returned as usual. However, any collective operations in fun will be computed over all participating devices, including those on other processes, via device-to-device communication. Conceptually, this can be thought of as running a pmap over a single array sharded across processes, where each process “sees” only its local shard of the input and output. The SPMD model requires that the same multi-process pmaps must be run in the same order on all devices, but they can be interspersed with arbitrary operations running in a single process.

Parameters
fun (~F) – Function to be mapped over argument axes. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_broadcasted_argnums can be anything at all, provided they are hashable and have an equality operation defined.

axis_name (Optional[Any]) – Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied.

in_axes – A non-negative integer, None, or nested Python container thereof that specifies which axes of positional arguments to map over. Arguments passed as keywords are always mapped over their leading axis (i.e. axis index 0). See vmap() for details.

out_axes – A non-negative integer, None, or nested Python container thereof indicating where the mapped axis should appear in the output. All outputs with a mapped axis must have a non-None out_axes specification (see vmap()).

static_broadcasted_argnums (Union[int, Iterable[int]]) – An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded. Calling the pmapped function with different values for these constants will trigger recompilation. If the pmapped function is called with fewer positional arguments than indicated by static_broadcasted_argnums then an error is raised. Each of the static arguments will be broadcasted to all devices. Arguments that are not arrays or containers thereof must be marked as static. Defaults to (). A minimal sketch appears after the examples below.

Static arguments must be hashable, meaning both __hash__ and __eq__ are implemented, and should be immutable.

devices (Optional[Sequence[Device]]) – This is an experimental feature and the API is likely to change. Optional, a sequence of Devices to map over. (Available devices can be retrieved via jax.devices()). Must be given identically for each process in multi-process settings (and will therefore include devices across processes). If specified, the size of the mapped axis must be equal to the number of devices in the sequence local to the given process. Nested pmap()s with devices specified in either the inner or outer pmap() are not yet supported.

backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend. ‘cpu’, ‘gpu’, or ‘tpu’.

axis_size (Optional[int]) – Optional; the size of the mapped axis.

donate_argnums (Union[int, Iterable[int]]) – Specify which arguments are “donated” to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to. (Donation also appears in the sketch after the examples below.)

global_arg_shapes (Optional[Tuple[Tuple[int, …], …]]) – Optional, must be set when using pmap(sharded_jit) and the partitioned values span multiple processes. The global cross-process per-replica shape of each argument, i.e. does not include the leading pmapped dimension. Can be None for replicated arguments. This API is likely to change in the future.
 Return type
~F
 Returns
A parallelized version of fun with arguments that correspond to those of fun but with extra array axes at positions indicated by in_axes and with output that has an additional leading array axis (with the same size).
For example, assuming 8 XLA devices are available, pmap() can be used as a map along a leading array axis:

>>> import jax.numpy as jnp
>>>
>>> out = pmap(lambda x: x ** 2)(jnp.arange(8))
>>> print(out)
[0, 1, 4, 9, 16, 25, 36, 49]
When the leading dimension is smaller than the number of available devices JAX will simply run on a subset of devices:

>>> x = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2))
>>> y = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
>>> out = pmap(jnp.dot)(x, y)
>>> print(out)
[[[    4.     9.]
  [   12.    29.]]
 [[  244.   345.]
  [  348.   493.]]
 [[ 1412.  1737.]
  [ 1740.  2141.]]]
If your leading dimension is larger than the number of available devices you will get an error:
>>> pmap(lambda x: x ** 2)(jnp.arange(9))
ValueError: ... requires 9 replicas, but only 8 XLA devices are available
As with vmap(), using None in in_axes indicates that an argument doesn’t have an extra axis and should be broadcasted, rather than mapped, across the replicas:

>>> x, y = jnp.arange(2.), 4.
>>> out = pmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None))(x, y)
>>> print(out)
([4., 5.], [8., 8.])
Note that pmap() always returns values mapped over their leading axis, equivalent to using out_axes=0 in vmap().
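The out_axes parameter in the signature above can place the mapped axis elsewhere; the original docstring gives no example of this, so here is a minimal hedged sketch (assuming this JAX version accepts non-zero out_axes, as the signature suggests, and reusing jnp from above):

>>> import jax
>>> n_dev = jax.local_device_count()
>>> out = pmap(lambda v: v * 2., out_axes=1)(jnp.ones((n_dev, 3)))
>>> out.shape == (3, n_dev)     # mapped axis moved from position 0 to position 1
True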
In addition to expressing pure maps, pmap() can also be used to express parallel single-program multiple-data (SPMD) programs that communicate via collective operations. For example:

>>> f = lambda x: x / jax.lax.psum(x, axis_name='i')
>>> out = pmap(f, axis_name='i')(jnp.arange(4.))
>>> print(out)
[ 0. 0.16666667 0.33333334 0.5 ]
>>> print(out.sum())
1.0
In this example, axis_name is a string, but it can be any Python object with __hash__ and __eq__ defined.

The argument axis_name to pmap() names the mapped axis so that collective operations, like jax.lax.psum(), can refer to it. Axis names are important particularly in the case of nested pmap() functions, where collective operations can operate over distinct axes:

>>> from functools import partial
>>> import jax
>>>
>>> @partial(pmap, axis_name='rows')
... @partial(pmap, axis_name='cols')
... def normalize(x):
...   row_normed = x / jax.lax.psum(x, 'rows')
...   col_normed = x / jax.lax.psum(x, 'cols')
...   doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
...   return row_normed, col_normed, doubly_normed
>>>
>>> x = jnp.arange(8.).reshape((4, 2))
>>> row_normed, col_normed, doubly_normed = normalize(x)
>>> print(row_normed.sum(0))
[ 1. 1.]
>>> print(col_normed.sum(1))
[ 1. 1. 1. 1.]
>>> print(doubly_normed.sum((0, 1)))
1.0
On multi-process platforms, collective operations operate over all devices, including those on other processes. For example, assuming the following code runs on two processes with 4 XLA devices each:

>>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
>>> data = jnp.arange(4) if jax.process_index() == 0 else jnp.arange(4, 8)
>>> out = pmap(f, axis_name='i')(data)
>>> print(out)
[28 29 30 31] # on process 0
[32 33 34 35] # on process 1

Each process passes in a different length-4 array, corresponding to its 4 local devices, and the psum operates over all 8 values. Conceptually, the two length-4 arrays can be thought of as a sharded length-8 array (in this example equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped axis given name ‘i’. The pmap call on each process then returns the corresponding length-4 output shard.
The devices argument can be used to specify exactly which devices are used to run the parallel computation. For example, again assuming a single process with 8 devices, the following code defines two parallel computations, one which runs on the first six devices and one on the remaining two:

>>> from functools import partial
>>> @partial(pmap, axis_name='i', devices=jax.devices()[:6])
... def f1(x):
...   return x / jax.lax.psum(x, axis_name='i')
>>>
>>> @partial(pmap, axis_name='i', devices=jax.devices()[-2:])
... def f2(x):
...   return jax.lax.psum(x ** 2, axis_name='i')
>>>
>>> print(f1(jnp.arange(6.)))
[0. 0.06666667 0.13333333 0.2 0.26666667 0.33333333]
>>> print(f2(jnp.array([2., 3.])))
[ 13. 13.]
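Neither static_broadcasted_argnums nor donate_argnums appears in the examples above. The following is a minimal sketch (not from the original docstring; it assumes jax and jnp are imported as above and uses the hypothetical helper name times) combining both: the scalar factor is treated as a compile-time constant broadcast to every device, and the buffer backing x is donated to the computation:

>>> from functools import partial
>>> n_dev = jax.local_device_count()
>>> @partial(pmap, static_broadcasted_argnums=1, donate_argnums=(0,))
... def times(x, factor):
...   return x * factor
>>> out = times(jnp.ones((n_dev, 3)), 2.)   # do not reuse the donated input buffer afterwards
>>> print(out[0])
[2. 2. 2.]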

jax.devices(backend=None)[source]¶ Returns a list of all devices for a given backend.
Each device is represented by a subclass of Device (e.g. CpuDevice, GpuDevice). The length of the returned list is equal to device_count(backend). Local devices can be identified by comparing Device.process_index to the value returned by jax.process_index().

If backend is None, returns all the devices from the default backend. The default backend is generally 'gpu' or 'tpu' if available, otherwise 'cpu'.
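As a small illustration of the Device.process_index comparison described above (a hedged sketch, not part of the original entry):

>>> import jax
>>> local = [d for d in jax.devices() if d.process_index == jax.process_index()]
>>> len(local) == jax.local_device_count()
True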

jax.local_devices(process_index=None, backend=None, host_id=None)[source]¶ Like jax.devices(), but only returns devices local to a given process.

If process_index is None, returns devices local to this process.

Parameters
 Return type
 Returns
List of Device subclasses.

jax.process_index(backend=None)[source]¶ Returns the integer process index of this process.

On most platforms, this will always be 0. This will vary on multi-process platforms though.

jax.device_count(backend=None)[source]¶ Returns the total number of devices.

On most platforms, this is the same as jax.local_device_count(). However, on multi-process platforms where different devices are associated with different processes, this will return the total number of devices across all processes.
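As a rough sanity check of how these counts relate (a hedged sketch assuming every process has the same number of local devices, as on typical homogeneous setups):

>>> import jax
>>> jax.device_count() == jax.process_count() * jax.local_device_count()
True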