jax package
Subpackages
Just-in-time compilation (jit)

jax.jit(fun, static_argnums=(), device=None, backend=None)
Sets up fun for just-in-time compilation with XLA.
Parameters: fun – Function to be jitted. Should be a pure function, as side-effects may only be executed once. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
Positional arguments indicated by static_argnums can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.
 static_argnums – A tuple of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded. Calling the jitted function with different values for these constants will trigger recompilation. If the jitted function is called with fewer positional arguments than indicated by static_argnums then an error is raised. Defaults to ().
 device – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].
 backend – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
Returns: A wrapped version of fun, set up for just-in-time compilation.
In the following example, selu can be compiled into a single fused kernel by XLA:
>>> @jax.jit
... def selu(x, alpha=1.67, lmbda=1.05):
...   return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
...
>>> key = jax.random.PRNGKey(0)
>>> x = jax.random.normal(key, (10,))
>>> print(selu(x))
[0.54485154 0.27744263 0.29255125 0.91421586 0.62452525 0.2474813 0.8574326 0.7823267 0.7682731 0.59566754]
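A minimal sketch of static_argnums, assuming a recent JAX release (the helper power is hypothetical). Its argument n drives a plain Python loop, so it has to be a concrete value when the function is traced; marking it static makes each distinct n a separate compilation cache entry.

```python
import jax
import jax.numpy as jnp


def power(x, n):
    # Hypothetical helper: `n` drives a Python loop that is unrolled at
    # trace time, so it must be a concrete (static) value under jit.
    result = jnp.ones_like(x)
    for _ in range(n):
        result = result * x
    return result


# Treat positional argument 1 (`n`) as a compile-time constant.
power_jit = jax.jit(power, static_argnums=(1,))

x = jnp.array([1.0, 2.0, 3.0])
cubes = power_jit(x, 3)    # compiled for n=3
squares = power_jit(x, 2)  # a new value of n triggers a recompilation
print(cubes)
print(squares)
```

Calling jax.jit(power) without static_argnums would fail here, because range() needs a concrete integer rather than a traced value.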

jax.disable_jit()
Context manager that disables jit behavior under its dynamic context.
For debugging purposes, it is useful to have a mechanism that disables jit everywhere in a dynamic context.
Values that have a data dependence on the arguments to a jitted function are traced and abstracted. For example, an abstract value may be a ShapedArray instance, representing the set of all possible arrays with a given shape and dtype, but not representing one concrete array with specific values. You might notice those if you use a benign side-effecting operation in a jitted function, like a print:
>>> @jax.jit
... def f(x):
...   y = x * 2
...   print("Value of y is", y)
...   return y + 3
...
>>> print(f(jax.numpy.array([1, 2, 3])))
Value of y is Traced<ShapedArray(int32[3]):JaxprTrace(level=1/1)>
[5 7 9]
Here y has been abstracted by jit to a ShapedArray, which represents an array with a fixed shape and type but an arbitrary value. It’s also traced. If we want to see a concrete value while debugging, and avoid the tracer too, we can use the disable_jit context manager:
>>> with jax.disable_jit():
...   print(f(jax.numpy.array([1, 2, 3])))
...
Value of y is [2 4 6]
[5 7 9]
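The same effect can be checked programmatically. In this sketch, which assumes a recent JAX release whose tracer repr begins with "Traced" as in the output above, the list observed records what the body of f sees for y: an abstract tracer under jit, and a concrete array inside the disable_jit block.

```python
import jax
import jax.numpy as jnp

observed = []  # what the Python body of `f` sees for `y`


@jax.jit
def f(x):
    y = x * 2
    observed.append(repr(y))  # Python side effect: runs only while tracing
    return y + 3


x = jnp.array([1, 2, 3])
out_jit = f(x)              # traced: `y` is an abstract tracer
with jax.disable_jit():
    out_eager = f(x)        # jit disabled: `y` is a concrete array
print(observed)
print(out_eager)
```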

jax.xla_computation(fun, static_argnums=(), axis_env=None, backend=None, tuple_args=False)
Creates a function that produces its XLA computation given example args.
Parameters: fun – Function from which to form XLA computations.
 static_argnums – See the jax.jit docstring.
 axis_env – Optional, a list of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap. See the examples below.
 backend – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
 tuple_args – Optional, defaults to False. If True, the resulting XLA computation will have a single tuple argument that is unpacked into the specified function arguments.
Returns: A wrapped version of fun that when applied to example arguments returns a built XLA Computation (see xla_client.py), from which representations of the unoptimized XLA HLO computation can be extracted using methods like GetHloText, GetSerializedProto, and GetHloDotGraph.
For example:
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> c = jax.xla_computation(f)(3.)
>>> print(c.GetHloText())
HloModule jaxpr_computation__4.5
ENTRY jaxpr_computation__4.5 {
  tuple.1 = () tuple()
  parameter.2 = f32[] parameter(0)
  cosine.3 = f32[] cosine(parameter.2)
  ROOT sine.4 = f32[] sine(cosine.3)
}
Here’s an example that involves a parallel collective and axis name:
>>> def f(x): return x - jax.lax.psum(x, 'i')
>>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2)
>>> print(c.GetHloText())
HloModule jaxpr_computation.9
primitive_computation.3 {
  parameter.4 = s32[] parameter(0)
  parameter.5 = s32[] parameter(1)
  ROOT add.6 = s32[] add(parameter.4, parameter.5)
}
ENTRY jaxpr_computation.9 {
  tuple.1 = () tuple()
  parameter.2 = s32[] parameter(0)
  all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3
  ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7)
}
Notice the replica_groups that were generated. Here’s an example that generates more interesting replica_groups:
>>> def g(x):
...   rowsum = jax.lax.psum(x, 'i')
...   colsum = jax.lax.psum(x, 'j')
...   allsum = jax.lax.psum(x, ('i', 'j'))
...   return rowsum, colsum, allsum
...
>>> axis_env = [('i', 4), ('j', 2)]
>>> c = jax.xla_computation(g, axis_env=axis_env)(5.)
>>> print(c.GetHloText())
HloModule jaxpr_computation__1.19
[removed uninteresting text here]
ENTRY jaxpr_computation__1.19 {
  tuple.1 = () tuple()
  parameter.2 = f32[] parameter(0)
  all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3
  all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8
  all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
  ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
}
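The xla_computation API documented here may not be available in later JAX releases. Assuming such a newer release, a comparable way to see what a function lowers to is to call jax.jit(f).lower(...) and then .as_text() on the result. This is a swapped-in technique rather than the API on this page, and it prints an MLIR dialect (typically StableHLO, depending on the version) instead of the HLO text above, so op names are spelled differently (for example stablehlo.cosine).

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(jnp.cos(x))


# Lower f for a scalar float example argument without compiling or running it.
lowered = jax.jit(f).lower(3.0)
ir_text = lowered.as_text()
print(ir_text)
```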

jax.make_jaxpr(fun)
Creates a function that produces its jaxpr given example args.
Parameters: fun – The function whose jaxpr is to be computed. Its positional arguments and return value should be arrays, scalars, or standard Python containers (tuple/list/dict) thereof.
Returns: A wrapped version of fun that when applied to example arguments returns a jaxpr representation of fun on those arguments.
A jaxpr is JAX’s intermediate representation for program traces. The jaxpr language is based on the simply-typed first-order lambda calculus with let-bindings. make_jaxpr adapts a function to return its jaxpr, which we can inspect to understand what JAX is doing internally.
The jaxpr returned is a trace of fun abstracted to ShapedArray level. Other levels of abstraction exist internally.
We do not describe the semantics of the jaxpr language in detail here, but instead give a few examples.
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> print(f(3.0))
-0.83602184
>>> jax.make_jaxpr(f)(3.0)
{ lambda ; ; a.
  let b = cos a
      c = sin b
  in c }
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda b ; ; a.
  let c = pack a
      (d) = id c
      e = cos d
      f = cos e
      g = mul b f
      h = neg g
      i = sin d
      j = mul h i
      k = pack j
      (l) = id k
  in l }
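The jaxpr text above reflects an early JAX release; newer versions print a different syntax. Assuming a recent release, where make_jaxpr returns a ClosedJaxpr whose .jaxpr.eqns attribute lists the equations, this sketch reads the primitive names directly instead of parsing the printed form.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(jnp.cos(x))


closed_jaxpr = jax.make_jaxpr(f)(3.0)
print(closed_jaxpr)

# Each equation applies one primitive; collect their names in program order.
prim_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(prim_names)
```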

jax.eval_shape(fun, *args, **kwargs)
Compute the shape/dtype of fun(*args, **kwargs) without any FLOPs.
This utility function is useful for performing shape inference. Its input/output behavior is defined by:

def eval_shape(fun, *args, **kwargs):
  out = fun(*args, **kwargs)
  return jax.tree_util.tree_map(shape_dtype_struct, out)

def shape_dtype_struct(x):
  return ShapeDtypeStruct(x.shape, x.dtype)

class ShapeDtypeStruct(object):
  __slots__ = ["shape", "dtype"]
  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = dtype
In particular, the output is a pytree of objects that have shape and dtype attributes, but nothing else about them is guaranteed by the API.
But instead of applying fun directly, which might be expensive, it uses JAX’s abstract interpretation machinery to evaluate the shapes without doing any FLOPs.
Using eval_shape can also catch shape errors, and will raise the same shape errors as evaluating fun(*args, **kwargs).
Parameters: *args – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types. Since only the shape and dtype attributes are accessed, only values that duck-type arrays are required, rather than real ndarrays. The duck-typed objects cannot be namedtuples because those are treated as standard Python containers. See the example below.
 **kwargs – a keyword argument dict of arrays, scalars, or (nested) standard Python containers (pytrees) of those types. As in args, array values need only be duck-typed to have shape and dtype attributes.
For example:
>>> f = lambda A, x: np.tanh(np.dot(A, x))
>>> class MyArgArray(object):
...   def __init__(self, shape, dtype):
...     self.shape = shape
...     self.dtype = dtype
...
>>> A = MyArgArray((2000, 3000), np.float32)
>>> x = MyArgArray((3000, 1000), np.float32)
>>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
>>> print(out.shape)
(2000, 1000)
>>> print(out.dtype)
float32
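Instead of a hand-written class like MyArgArray, the arguments can be jax.ShapeDtypeStruct instances, assuming a JAX release that exports that class. This sketch repeats the example above and also shows a mismatched inner dimension being reported without any FLOPs.

```python
import jax
import jax.numpy as jnp


def f(A, x):
    return jnp.tanh(jnp.dot(A, x))


# Shape/dtype placeholders: no array memory is allocated for these.
A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)

out = jax.eval_shape(f, A, x)
print(out.shape, out.dtype)

# A mismatched inner dimension is caught by abstract evaluation alone.
bad_x = jax.ShapeDtypeStruct((2999, 1000), jnp.float32)
try:
    jax.eval_shape(f, A, bad_x)
    shape_error = None
except Exception as err:  # the exact exception type may vary across JAX versions
    shape_error = err
print(type(shape_error).__name__)
```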
Automatic differentiation

jax.grad(fun, argnums=0, has_aux=False, holomorphic=False)
Creates a function which evaluates the gradient of fun.
Parameters:  fun – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
 argnums – Optional, integer or tuple of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
 has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
 holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
Returns: A function with the same arguments as fun, that evaluates the gradient of fun. If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a pair of (gradient, auxiliary_data) is returned.
For example:
>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.961043
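A short sketch combining argnums and has_aux, assuming a recent JAX release; loss is a hypothetical linear-model objective that also returns its predictions as auxiliary data. With argnums=(0, 1) the gradient is a tuple matching w and b, and with has_aux=True the auxiliary value comes back alongside it.

```python
import jax
import jax.numpy as jnp


def loss(w, b, x):
    # Hypothetical linear model; returns (scalar loss, auxiliary predictions).
    pred = w * x + b
    return jnp.sum(pred ** 2), pred


x = jnp.array([1.0, 2.0])
grad_fn = jax.grad(loss, argnums=(0, 1), has_aux=True)
(dw, db), pred = grad_fn(2.0, 1.0, x)

# By hand: pred = [3, 5]; dL/dw = sum(2 * pred * x) = 26; dL/db = sum(2 * pred) = 16
print(dw, db, pred)
```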

jax.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False)
Creates a function which evaluates both fun and the gradient of fun.
Parameters:  fun – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
 argnums – Optional, integer or tuple of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
 has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
 holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
Returns: A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments.
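A minimal sketch of value_and_grad, assuming a recent JAX release: for a sum of squares, one call returns both the value and the gradient 2 * x.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(x):
    return jnp.sum(x ** 2)


x = jnp.array([1.0, 2.0, 3.0])
value, grad = jax.value_and_grad(sum_of_squares)(x)
print(value)  # 1 + 4 + 9
print(grad)   # 2 * x
```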

jax.jacfwd(fun, argnums=0, holomorphic=False)
Jacobian of fun evaluated column-by-column using forward-mode AD.
Parameters:  fun – Function whose Jacobian is to be computed.
 argnums – Optional, integer or tuple of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
 holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
Returns: A function with the same arguments as fun, that evaluates the Jacobian of fun using forward-mode automatic differentiation.
>>> def f(x):
...   return jax.numpy.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jax.numpy.sin(x[0])])
...
>>> print(jax.jacfwd(f)(jax.numpy.array([1., 2., 3.])))
[[ 1. , 0. , 0. ],
 [ 0. , 0. , 5. ],
 [ 0. , 16. , -2. ],
 [ 1.6209068 , 0. , 0.84147096]]
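In the returned Jacobian, rows index outputs and columns index inputs. Forward mode builds it one column (one input direction) at a time, so it is generally the cheaper choice when a function has fewer inputs than outputs; that rule of thumb is general AD background rather than something stated on this page. A sketch for a function from R^2 to R^3, assuming a recent JAX release:

```python
import jax
import jax.numpy as jnp


def f(x):
    # Maps R^2 -> R^3: fewer inputs than outputs.
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])


x = jnp.array([1.0, 2.0])
J = jax.jacfwd(f)(x)  # shape (num_outputs, num_inputs)

# Hand-derived rows: [x1, x0], [cos(x0), 0], [0, 2*x1]
expected = jnp.array([[2.0, 1.0], [float(jnp.cos(1.0)), 0.0], [0.0, 4.0]])
print(J.shape)
print(J)
```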

jax.jacrev(fun, argnums=0, holomorphic=False)
Jacobian of fun evaluated row-by-row using reverse-mode AD.
Parameters:  fun – Function whose Jacobian is to be computed.
 argnums – Optional, integer or tuple of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
 holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
Returns: A function with the same arguments as fun, that evaluates the Jacobian of fun using reverse-mode automatic differentiation.
>>> def f(x):
...   return jax.numpy.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jax.numpy.sin(x[0])])
...
>>> print(jax.jacrev(f)(jax.numpy.array([1., 2., 3.])))
[[ 1. , 0. , 0. ],
 [ 0. , 0. , 5. ],
 [ 0. , 16. , -2. ],
 [ 1.6209068 , 0. , 0.84147096]]
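Conversely, reverse mode builds the Jacobian one row (one output) at a time, which generally favors functions with many inputs and few outputs. This sketch, assuming a recent JAX release, checks that jacrev and jacfwd agree on a function from R^4 to R^1, whose single Jacobian row is the gradient 2 * x.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Maps R^4 -> R^1: many inputs, a single output.
    return jnp.array([jnp.sum(x ** 2)])


x = jnp.arange(4.0)        # [0., 1., 2., 3.]
J_rev = jax.jacrev(f)(x)   # one reverse pass per output row
J_fwd = jax.jacfwd(f)(x)   # one forward pass per input column
print(J_rev)
```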

jax.hessian(fun, argnums=0, holomorphic=False)
Hessian of fun.
Parameters:  fun – Function whose Hessian is to be computed.
 argnums – Optional, integer or tuple of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
 holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
Returns: A function with the same arguments as fun, that evaluates the Hessian of fun.
>>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
>>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
[[ 6., -2.], [ -2., -480.]]
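A quick cross-check of the example above, assuming a recent JAX release: composing a forward-mode Jacobian over a reverse-mode one gives the same matrix of second derivatives, and both match the hand-derived entries 6*x0, -2 and -30*x1**4 at (1, 2).

```python
import jax
import jax.numpy as jnp


def g(x):
    return x[0] ** 3 - 2 * x[0] * x[1] - x[1] ** 6


x = jnp.array([1.0, 2.0])
H = jax.hessian(g)(x)
H_composed = jax.jacfwd(jax.jacrev(g))(x)

# d2g/dx0^2 = 6*x0, d2g/dx0dx1 = -2, d2g/dx1^2 = -30*x1**4
expected = jnp.array([[6.0, -2.0], [-2.0, -480.0]])
print(H)
```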

jax.jvp(fun, primals, tangents)
Computes a (forward-mode) Jacobian-vector product of fun.
Parameters:  fun – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
 primals – The primal values at which the Jacobian of fun should be evaluated. Should be a tuple of arrays, scalar, or standard Python container thereof. The length of the tuple is equal to the number of positional parameters of fun.
 tangents – The tangent vector for which the Jacobian-vector product should be evaluated. Should be a tuple of arrays, scalar, or standard Python container thereof, with the same tree structure and array shapes as primals.
Returns: A (primals_out, tangents_out) pair, where primals_out is fun(*primals), and tangents_out is the Jacobian-vector product of fun evaluated at primals with tangents. The tangents_out value has the same Python tree structure and shapes as primals_out.
For example:
>>> y, v = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
>>> print(y)
0.09983342
>>> print(v)
0.19900084
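A sketch with two positional arguments, assuming a recent JAX release: primals and tangents are tuples with one entry per argument, so the tangent (1.0, 0.0) gives the directional derivative along x only, which by hand is y + cos(x).

```python
import math

import jax
import jax.numpy as jnp


def f(x, y):
    return x * y + jnp.sin(x)


# Tangent (1.0, 0.0): perturb x only, keep y fixed.
out, tangent_out = jax.jvp(f, (2.0, 3.0), (1.0, 0.0))

expected_out = 2.0 * 3.0 + math.sin(2.0)   # f(2, 3)
expected_tangent = 3.0 + math.cos(2.0)     # df/dx = y + cos(x)
print(out, tangent_out)
```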

jax.linearize(fun, *primals)
Produce a linear approximation to fun using jvp and partial evaluation.
Parameters:  fun – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
 primals – The primal values at which the Jacobian of fun should be evaluated. Should be a tuple of arrays, scalar, or standard Python container thereof. The length of the tuple is equal to the number of positional parameters of fun.
Returns: A pair where the first element is the value of fun(*primals) and the second element is a function that evaluates the (forward-mode) Jacobian-vector product of fun evaluated at primals without redoing the linearization work.
In terms of values computed, linearize behaves much like a curried jvp, where these two code blocks compute the same values:
y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

y, f_jvp = jax.linearize(f, x)
out_tangent = f_jvp(in_tangent)
However, the difference is that linearize uses partial evaluation so that the function f is not re-linearized on calls to f_jvp. In general that means the memory usage scales with the size of the computation, much like in reverse-mode. (Indeed, linearize has a similar signature to vjp!)
This function is mainly useful if you want to apply f_jvp multiple times, i.e. to evaluate a pushforward for many different input tangent vectors at the same linearization point. Moreover if all the input tangent vectors are known at once, it can be more efficient to vectorize using vmap, as in:
pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
By using vmap and jvp together like this we avoid the stored-linearization memory cost that scales with the depth of the computation, which is incurred by both linearize and vjp.
Here’s a more complete example of using linearize:
>>> def f(x): return 3. * np.sin(x) + np.cos(x / 2.)
...
>>> jax.jvp(f, (2.,), (3.,))
(array(3.2681944, dtype=float32), array(-5.007528, dtype=float32))
>>> y, f_jvp = jax.linearize(f, 2.)
>>> print(y)
3.2681944
>>> print(f_jvp(3.))
-5.007528
>>> print(f_jvp(4.))
-6.676704
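The two strategies described above can be compared directly. This sketch, assuming a recent JAX release, reuses f from the example, pushes three tangents through a single linearize call, and then computes the same pushforwards with vmap over jvp, following the pushfwd pattern.

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(x):
    return 3.0 * jnp.sin(x) + jnp.cos(x / 2.0)


x = 2.0
tangents = [3.0, 4.0, 5.0]

# Option 1: linearize once, then reuse f_jvp for every tangent.
y, f_jvp = jax.linearize(f, x)
lin_out = jnp.stack([f_jvp(t) for t in tangents])

# Option 2: all tangents are known up front, so vmap a jvp over them.
# out_axes=(None, 0): the primal output is not batched, the tangent output is.
pushfwd = partial(jax.jvp, f, (x,))
y2, vmap_out = jax.vmap(pushfwd, out_axes=(None, 0))((jnp.array(tangents),))

print(lin_out)
print(vmap_out)
```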

jax.vjp(fun, *primals, **kwargs)
Compute a (reverse-mode) vector-Jacobian product of fun.
grad is implemented as a special case of vjp.
Parameters:  fun – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
 primals – A sequence of primal values at which the Jacobian of fun should be evaluated. The length of primals should be equal to the number of positional parameters to fun. Each primal value should be a tuple of arrays, scalar, or standard Python containers thereof.
 has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
Returns: A (primals_out, vjpfun) pair, where primals_out is fun(*primals). vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same shape as primals, representing the vector-Jacobian product of fun evaluated at primals.
>>> def f(x, y):
...   return jax.numpy.sin(x), jax.numpy.cos(y)
...
>>> primals, f_vjp = jax.vjp(f, 0.5, 1.0)
>>> xbar, ybar = f_vjp((0.7, 0.3))
>>> print(xbar)
0.61430776
>>> print(ybar)
-0.2524413
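Because grad is a special case of vjp, pulling back an output cotangent of 1 through a scalar-valued function should reproduce its gradient. A sketch, assuming a recent JAX release:

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(jnp.sin(x))  # scalar output


x = jnp.array([0.0, 1.0, 2.0])
y, f_vjp = jax.vjp(f, x)
(x_bar,) = f_vjp(jnp.ones_like(y))  # seed with output cotangent 1
print(x_bar)            # equals cos(x)
print(jax.grad(f)(x))   # same values
```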

jax.custom_transforms(fun)
Wraps a function so that its transformation behavior can be controlled.
A primary use case of custom_transforms is defining custom VJP rules (aka custom gradients) for a Python function, while still supporting other transformations like jax.jit and jax.vmap. Custom differentiation rules can be supplied using the jax.defjvp and jax.defvjp functions.
The custom_transforms decorator wraps fun so that its transformation behavior can be overridden, but not all transformation rules need to be specified manually. The default behavior is retained for any non-overridden rules.
The function fun must satisfy the same constraints required for jit compilation. In particular the shapes of arrays in the computation of fun may depend on the shapes of fun’s arguments, but not their values. Value-dependent Python control flow is also not yet supported.
Parameters: fun – a Python callable. Must be functionally pure. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
Returns: A Python callable with the same input/output and transformation behavior as fun, but for which custom transformation rules can be supplied, e.g. using jax.defvjp.
For example:
>>> @jax.custom_transforms
... def f(x):
...   return np.sin(x ** 2)
...
>>> print(f(3.))
0.4121185
>>> print(jax.grad(f)(3.))
-5.4667816
>>> jax.defvjp(f, lambda g, ans, x: g * x)
>>> print(jax.grad(f)(3.))
3.0
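Later JAX releases replaced custom_transforms and the defjvp/defvjp family with the jax.custom_jvp and jax.custom_vjp decorators, so the example above may not run on a current installation. As a swapped-in sketch using jax.custom_vjp (not the API documented on this page), the same rule g * x is expressed with a forward function that returns the primal output plus residuals, and a backward function that maps (residuals, output cotangent) to a tuple of input cotangents.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x):
    return jnp.sin(x ** 2)


def f_fwd(x):
    # Forward pass: primal output plus residuals saved for the backward pass.
    return f(x), x


def f_bwd(x, g):
    # Backward pass: (residuals, output cotangent) -> tuple of input cotangents.
    # Deliberately uses the rule g * x, mirroring the doctest above.
    return (g * x,)


f.defvjp(f_fwd, f_bwd)

print(f(3.0))            # primal value is unchanged: sin(9)
print(jax.grad(f)(3.0))  # custom rule: 1 * 3.0
```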

jax.defjvp(fun, *jvprules)
Define JVP rules for each argument separately.
This function is a convenience wrapper around jax.defjvp_all for separately defining JVP rules for each of the function’s arguments. This convenience wrapper does not provide a mechanism for depending on anything other than the function arguments and its primal output value, though depending on intermediate results is possible using jax.defjvp_all.
The signature of each component JVP rule is lambda g, ans, *primals: ... where g represents the tangent of the corresponding positional argument, ans represents the output primal, and *primals represents all the primal positional arguments.
Defining a custom JVP rule also affects the default VJP rule, which is derived from the JVP rule automatically via transposition.
Parameters:  fun – a custom_transforms function.
 *jvprules – a sequence of functions or Nones specifying the JVP rule for each corresponding positional argument. When an element is None, it indicates that the Jacobian from the corresponding input to the output is zero.
Returns: None. A side-effect is that fun is associated with the JVP rule specified by *jvprules.
For example:
>>> @jax.custom_transforms
... def f(x):
...   return np.sin(x ** 2)
...
>>> print(f(3.))
0.4121185
>>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
>>> print(out_primal)
0.4121185
>>> print(out_tangent)
-10.933563
>>> jax.defjvp(f, lambda g, ans, x: 8. * g + ans)
>>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
>>> print(out_primal)
0.4121185
>>> print(out_tangent)
16.412119

jax.defjvp_all(fun, custom_jvp)
Define a custom JVP rule for a custom_transforms function.
If fun represents a function with signature a -> b, then custom_jvp represents a function with signature (a, T a) -> (b, T b), where we use T x to represent a tangent type for the type x.
In more detail, custom_jvp must take two arguments, both tuples of length equal to the number of positional arguments to fun. The first argument to custom_jvp represents the input primal values, and the second represents the input tangent values. custom_jvp must return a pair where the first element represents the output primal value and the second element represents the output tangent value.
Defining a custom JVP rule also affects the default VJP rule, which is derived from the JVP rule automatically via transposition.
Parameters:  fun – a custom_transforms function.
 custom_jvp – a Python callable specifying the JVP rule, taking two tuples as arguments specifying the input primal values and tangent values, respectively. The tuple elements can be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. The output must be a pair representing the primal output and tangent output, which can be arrays, scalars, or (nested) standard Python containers. Must be functionally pure.
Returns: None. A side-effect is that fun is associated with the JVP rule specified by custom_jvp.
For example:
>>> @jax.custom_transforms
... def f(x):
...   return np.sin(x ** 2)
...
>>> print(f(3.))
0.4121185
>>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
>>> print(out_primal)
0.4121185
>>> print(out_tangent)
-10.933563
>>> jax.defjvp_all(f, lambda ps, ts: (np.sin(ps[0] ** 2), 8. * ts[0]))
>>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
>>> print(out_primal)
0.4121185
>>> print(out_tangent)
16.0
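For comparison, the newer jax.custom_jvp decorator (a later API, not the one documented here) takes a rule with the same shape as custom_jvp above: tuples of primals and tangents in, an (output primal, output tangent) pair out. A sketch reproducing the rule from the example, assuming a JAX release that provides jax.custom_jvp:

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def f(x):
    return jnp.sin(x ** 2)


@f.defjvp
def f_jvp(primals, tangents):
    # Tuples of primals and tangents in; (primal_out, tangent_out) out.
    # The rule 8 * t mirrors the defjvp_all doctest above.
    (x,), (t,) = primals, tangents
    return jnp.sin(x ** 2), 8.0 * t


out_primal, out_tangent = jax.jvp(f, (3.0,), (2.0,))
print(out_primal)   # sin(9)
print(out_tangent)  # 8 * 2
```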

jax.defvjp(fun, *vjprules)
Define VJP rules for each argument separately.
This function is a convenience wrapper around jax.defvjp_all for separately defining VJP rules for each of the function’s arguments. This convenience wrapper does not provide a mechanism for depending on anything other than the function arguments and its primal output value, though depending on intermediate results is possible using jax.defvjp_all.
The signature of each component VJP rule is lambda g, ans, *primals: ... where g represents the output cotangent, ans represents the output primal, and *primals represents all the primal positional arguments.
Parameters: fun – a custom_transforms function.
 *vjprules – a sequence of functions or Nones specifying the VJP rule for each corresponding positional argument. When an element is None, it indicates that the Jacobian from the corresponding input to the output is zero.
Returns: None. A side-effect is that fun is associated with the VJP rule specified by *vjprules.
For example:
>>> @jax.custom_transforms
... def f(x, y):
...   return np.sin(x ** 2 + y)
...
>>> print(f(3., 4.))
0.42016703
>>> print(jax.grad(f)(3., 4.))
5.4446807
>>> print(jax.grad(f, 1)(3., 4.))
0.9074468
>>> jax.defvjp(f, None, lambda g, ans, x, y: g + x + y + ans)
>>> print(jax.grad(f)(3., 4.))
0.0
>>> print(jax.grad(f, 1)(3., 4.))
8.420167
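In the newer jax.custom_vjp API (an assumption about later JAX releases, not this page's API), a per-argument rule of None corresponds to the backward function returning a zero cotangent for that argument. A sketch of the two-argument example above, with the rule for y copied from the doctest:

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x, y):
    return jnp.sin(x ** 2 + y)


def f_fwd(x, y):
    ans = f(x, y)
    return ans, (x, y, ans)  # residuals needed by the made-up rule below


def f_bwd(res, g):
    x, y, ans = res
    # Zero cotangent for x (the role of None in jax.defvjp above);
    # the rule for y copies the doctest: g + x + y + ans.
    return jnp.zeros_like(x), g + x + y + ans


f.defvjp(f_fwd, f_bwd)

dx = jax.grad(f)(3.0, 4.0)     # 0.0
dy = jax.grad(f, 1)(3.0, 4.0)  # 1 + 3 + 4 + sin(13)
print(dx, dy)
```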

jax.defvjp_all(fun, custom_vjp)
Define a custom VJP rule for a custom_transforms function.
If fun represents a function with signature a -> b, then custom_vjp represents a function with signature a -> (b, CT b -> CT a), where we use CT x to represent a cotangent type for the type x. That is, custom_vjp should take the same arguments as fun and return a pair where the first element represents the primal value of fun applied to the arguments, and the second element is a VJP function that maps from output cotangents to input cotangents, returning a tuple with length equal to the number of positional arguments supplied to fun.
The VJP function returned as the second element of the output of custom_vjp can close over intermediate values computed when evaluating the primal value of fun. That is, use lexical closure to share work between the forward pass and the backward pass of reverse-mode automatic differentiation.
See also jax.custom_gradient.
Parameters: fun – a custom_transforms function.
 custom_vjp – a Python callable specifying the VJP rule, taking the same arguments as fun and returning a pair where the first element is the value of fun applied to the arguments and the second element is a Python callable representing the VJP map from output cotangents to input cotangents. The returned VJP function must accept a value with the same shape as the value of fun applied to the arguments and must return a tuple with length equal to the number of positional arguments to fun. Arguments can be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Must be functionally pure.
Returns: None. A side-effect is that fun is associated with the VJP rule specified by custom_vjp.
For example:
>>> @jax.custom_transforms
... def f(x):
...   return np.sin(x ** 2)
...
>>> print(f(3.))
0.4121185
>>> print(jax.grad(f)(3.))
-5.4667816
>>> jax.defvjp_all(f, lambda x: (np.sin(x ** 2), lambda g: (g * x,)))
>>> print(f(3.))
0.4121185
>>> print(jax.grad(f)(3.))
3.0
An example with a function on two arguments, so that the VJP function must return a tuple of length two:
>>> @jax.custom_transforms
... def f(x, y):
...   return x * y
...
>>> jax.defvjp_all(f, lambda x, y: (x * y, lambda g: (y, x)))
>>> print(f(3., 4.))
12.0
>>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
(4.0, 3.0)
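defvjp_all lets the VJP function close over values computed in the forward pass. With the newer jax.custom_vjp decorator (a later API, not documented on this page), the same work sharing is written as explicit residuals. In this sketch x ** 2 and cos(x ** 2) are computed once in the forward pass and reused to give the true derivative 2 * x * cos(x ** 2), so the result should agree with the default, non-custom gradient 6 * cos(9).

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x):
    return jnp.sin(x ** 2)


def f_fwd(x):
    x_sq = x ** 2                              # intermediate computed once
    return jnp.sin(x_sq), (x, jnp.cos(x_sq))   # residuals for the backward pass


def f_bwd(res, g):
    x, cos_x_sq = res
    return (g * 2.0 * x * cos_x_sq,)           # d/dx sin(x^2) = 2x cos(x^2)


f.defvjp(f_fwd, f_bwd)

custom_grad = jax.grad(f)(3.0)
print(custom_grad)
```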

jax.custom_gradient(fun)
Convenience function for defining custom VJP rules (aka custom gradients).
While the canonical way to define custom VJP rules is via jax.defvjp_all and its convenience wrappers, the custom_gradient convenience wrapper follows TensorFlow’s tf.custom_gradient API. The difference here is that custom_gradient can be used as a decorator on one function that returns both the primal value (representing the output of the mathematical function to be differentiated) and the VJP (gradient) function.
See https://www.tensorflow.org/api_docs/python/tf/custom_gradient.
If the mathematical function to be differentiated has type signature a -> b, then the Python callable fun should have signature a -> (b, CT b -> CT a), where we use CT x to denote a cotangent type for x. See the example below. That is, fun should return a pair where the first element represents the value of the mathematical function to be differentiated and the second element is a function that represents the custom VJP rule.
The custom VJP function returned as the second element of the output of fun can close over intermediate values computed when evaluating the function to be differentiated. That is, use lexical closure to share work between the forward pass and the backward pass of reverse-mode automatic differentiation.
Parameters: fun – a Python callable specifying both the mathematical function to be differentiated and its reverse-mode differentiation rule. It should return a pair consisting of an output value and a Python callable that represents the custom gradient function.
Returns: A Python callable with signature a -> b, i.e. that returns the output value specified by the first element of fun’s output pair. A side effect is that under the hood jax.defvjp_all is called to set up the returned Python callable with the custom VJP rule specified by the second element of fun’s output pair.
For example:
>>> @jax.custom_gradient
... def f(x):
...   return x ** 2, lambda g: (g * x,)
...
>>> print(f(3.))
9.0
>>> print(jax.grad(f)(3.))
3.0
An example with a function on two arguments, so that the VJP function must return a tuple of length two:
>>> @jax.custom_gradient
... def f(x, y):
...   return x * y, lambda g: (y, x)
...
>>> print(f(3., 4.))
12.0
>>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
(4.0, 3.0)
Vectorization (vmap)

jax.vmap(fun, in_axes=0, out_axes=0)
Vectorizing map. Creates a function which maps fun over argument axes.
Parameters:  fun – Function to be mapped over additional axes.
 in_axes – A non-negative integer, None, or (nested) standard Python container (tuple/list/dict) thereof specifying which input array axes to map over. If each positional argument to fun is an array, then in_axes can be a non-negative integer, a None, or a tuple of integers and Nones with length equal to the number of positional arguments to fun. An integer or None indicates which array axis to map over for all arguments (with None indicating not to map any axis), and a tuple indicates which axis to map for each corresponding positional argument. If the positional arguments to fun are container types, the corresponding element of in_axes can itself be a matching container, so that distinct array axes can be mapped for different container elements. in_axes must be a container tree prefix of the positional argument tuple passed to fun.
 out_axes – A non-negative integer, None, or (nested) standard Python container (tuple/list/dict) thereof indicating where the mapped axis should appear in the output.
Returns: Batched/vectorized version of fun with arguments that correspond to those of fun, but with extra array axes at positions indicated by in_axes, and a return value that corresponds to that of fun, but with extra array axes at positions indicated by out_axes.
For example, we can implement a matrix-matrix product using a vector dot product:
>>> vv = lambda x, y: np.vdot(x, y)  # ([a], [a]) -> []
>>> mv = vmap(vv, (0, None), 0)      # ([b,a], [a]) -> [b]      (b is the mapped axis)
>>> mm = vmap(mv, (None, 1), 1)      # ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)
Here we use [a,b] to indicate an array with shape (a,b). Here are some variants:
>>> mv1 = vmap(vv, (0, 0), 0)   # ([b,a], [b,a]) -> [b]        (b is the mapped axis)
>>> mv2 = vmap(vv, (0, 1), 0)   # ([b,a], [a,b]) -> [b]        (b is the mapped axis)
>>> mm2 = vmap(mv2, (1, 1), 0)  # ([b,c,a], [a,c,b]) -> [c,b]  (c is the mapped axis)
Here’s an example of using container types in in_axes to specify which axes of the container elements to map over:

>>> import jax.numpy as np
>>> from jax import vmap
>>> A, B, C, D = 2, 3, 4, 5
>>> x = np.ones((A, B))
>>> y = np.ones((B, C))
>>> z = np.ones((C, D))
>>> def foo(tree_arg):
...   x, (y, z) = tree_arg
...   return np.dot(x, np.dot(y, z))
>>> tree = (x, (y, z))
>>> print(foo(tree))
[[12. 12. 12. 12. 12.]
 [12. 12. 12. 12. 12.]]
>>> K = 6  # batch size
>>> x = np.ones((K, A, B))  # batch axis in different locations
>>> y = np.ones((B, K, C))
>>> z = np.ones((C, D, K))
>>> tree = (x, (y, z))
>>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
>>> print(vfoo(tree).shape)
(6, 2, 5)
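Container entries in in_axes can also be None, which passes that leaf unchanged to every batch element. The sketch below maps over a batch of inputs while sharing a single parameter dict; the predict function and its 'w'/'b' keys are hypothetical, chosen only for illustration.

```python
import jax.numpy as jnp
from jax import vmap

def predict(params, x):
    # Hypothetical affine model: params['w'] has shape [d, k], params['b'] has shape [k].
    return jnp.dot(x, params['w']) + params['b']

d, k, batch = 3, 2, 5  # arbitrary illustrative sizes
params = {'w': jnp.ones((d, k)), 'b': jnp.zeros((k,))}
xs = jnp.arange(batch * d, dtype=jnp.float32).reshape((batch, d))

# The dict in in_axes mirrors the structure of `params`: None means "do not map"
# this leaf, while 0 maps over the leading axis of `xs`.
batched_predict = vmap(predict, in_axes=({'w': None, 'b': None}, 0))
out = batched_predict(params, xs)
print(out.shape)  # (5, 2)
```

Because in_axes only needs to be a tree prefix of the arguments, in_axes=(None, 0) would be an equivalent, shorter spelling here.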
Parallelization (pmap)

jax.pmap(fun, axis_name=None, devices=None, backend=None)
Parallel map with support for collectives.
The purpose of pmap is to express single-program multiple-data (SPMD) programs and execute them in parallel on XLA devices, such as multiple GPUs or multiple TPU cores. Semantically it is comparable to vmap because both transformations map a function over array axes, but where vmap vectorizes functions by pushing the mapped axis down into primitive operations, pmap instead replicates the function and executes each replica on its own XLA device in parallel.

Another key difference with vmap is that while vmap can only express pure maps, pmap enables the use of parallel SPMD collective operations, like all-reduce sum.

The mapped axis size must be less than or equal to the number of local XLA devices available, as returned by jax.local_device_count() (unless devices is specified, see below). For nested pmap calls, the product of the mapped axis sizes must be less than or equal to the number of XLA devices.

Multi-host platforms: On multi-host platforms such as TPU pods, pmap is designed to be used in SPMD Python programs, where every host is running the same Python code such that all hosts run the same pmapped function in the same order. Each host should still call the pmapped function with mapped axis size equal to the number of local devices (unless devices is specified, see below), and an array of the same leading axis size will be returned as usual. However, any collective operations in fun will be computed over all participating devices, including those on other hosts, via device-to-device communication. Conceptually, this can be thought of as running a pmap over a single array sharded across hosts, where each host “sees” only its local shard of the input and output.

Parameters:
 fun – Function to be mapped over argument axes. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
 axis_name – Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied.
 devices – This is an experimental feature and the API is likely to change. Optional, a sequence of Devices to map over. (Available devices can be retrieved via jax.devices().) If specified, the size of the mapped axis must be equal to the number of local devices in the sequence. Nested pmaps with devices specified in either the inner or outer pmap are not yet supported.
 backend – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: ‘cpu’, ‘gpu’, or ‘tpu’.
Returns: A parallelized version of fun with arguments that correspond to those of fun but each with an additional leading array axis (with equal sizes) and with output that has an additional leading array axis (with the same size).

For example, assuming 8 XLA devices are available, pmap can be used as a map along a leading array axis:

>>> import jax.numpy as np
>>> from jax import pmap
>>> out = pmap(lambda x: x ** 2)(np.arange(8))
>>> print(out)
[ 0  1  4  9 16 25 36 49]
>>> x = np.arange(3 * 2 * 2.).reshape((3, 2, 2))
>>> y = np.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
>>> out = pmap(np.dot)(x, y)
>>> print(out)
[[[   4.    9.]
  [  12.   29.]]
 [[ 244.  345.]
  [ 348.  493.]]
 [[1412. 1737.]
  [1740. 2141.]]]
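For a pure map like the ones above, pmap and vmap should agree numerically; they differ in where the work runs. This sketch sizes the mapped axis with jax.local_device_count() so that it also runs on a machine with a single CPU device, where pmap simply maps over an axis of size 1. The square_sum function is a hypothetical example.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()  # the mapped axis size may not exceed this
x = jnp.arange(n * 4, dtype=jnp.float32).reshape((n, 4))

square_sum = lambda row: jnp.sum(row ** 2)

out_pmap = jax.pmap(square_sum)(x)  # one replica per local device
out_vmap = jax.vmap(square_sum)(x)  # vectorized on a single device

print(out_pmap.shape)  # shape (n,)
print(bool(jnp.allclose(out_pmap, out_vmap)))  # True
```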
In addition to expressing pure maps, pmap can also be used to express parallel single-program multiple-data (SPMD) programs that communicate via collective operations. For example:

>>> f = lambda x: x / jax.lax.psum(x, axis_name='i')
>>> out = pmap(f, axis_name='i')(np.arange(4.))
>>> print(out)
[ 0. 0.16666667 0.33333334 0.5 ]
>>> print(out.sum())
1.0
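The same normalization pattern can be written so that it runs on however many local devices are present. This sketch assumes only that at least one device exists; it starts the input at 1 rather than 0 so that the single-device case does not compute 0 / 0.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(1.0, n + 1.0)  # one scalar per device: 1., 2., ..., n

# Each replica divides its value by the sum over all replicas on axis 'i'.
normalize = jax.pmap(lambda v: v / jax.lax.psum(v, axis_name='i'), axis_name='i')
out = normalize(x)

print(out)               # entry k is (k + 1) / (n * (n + 1) / 2)
print(float(out.sum()))  # approximately 1.0
```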
In this example, axis_name is a string, but it can be any Python object with __hash__ and __eq__ defined.

The argument axis_name to pmap names the mapped axis so that collective operations, like jax.lax.psum, can refer to it. Axis names are particularly important in the case of nested pmap functions, where collectives can operate over distinct axes:

>>> from functools import partial
>>> @partial(pmap, axis_name='rows')
... @partial(pmap, axis_name='cols')
... def normalize(x):
...   row_normed = x / jax.lax.psum(x, 'rows')
...   col_normed = x / jax.lax.psum(x, 'cols')
...   doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
...   return row_normed, col_normed, doubly_normed
>>> x = np.arange(8.).reshape((4, 2))
>>> row_normed, col_normed, doubly_normed = normalize(x)
>>> print(row_normed.sum(0))
[ 1. 1.]
>>> print(col_normed.sum(1))
[ 1. 1. 1. 1.]
>>> print(doubly_normed.sum((0, 1)))
1.0
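Other collectives in jax.lax take the same axis_name argument. As a sketch restricted to a single, non-nested pmap (so it runs on any number of local devices), the following centers per-device values with jax.lax.pmean and finds the elementwise global maximum with jax.lax.pmax. The axis name 'devices' and the center_and_max function are illustrative choices.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(n * 3, dtype=jnp.float32).reshape((n, 3))

def center_and_max(v):
    # pmean averages v elementwise across all replicas on axis 'devices'.
    centered = v - jax.lax.pmean(v, axis_name='devices')
    # pmax takes the elementwise maximum of v across the same replicas.
    global_max = jax.lax.pmax(v, axis_name='devices')
    return centered, global_max

centered, global_max = jax.pmap(center_and_max, axis_name='devices')(x)
print(centered.shape, global_max.shape)  # both have shape (n, 3)
```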
On multi-host platforms, collective operations operate over all devices, including those on other hosts. For example, assuming the following code runs on two hosts with 4 XLA devices each:
>>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
>>> data = np.arange(4) if jax.host_id() == 0 else np.arange(4, 8)
>>> out = pmap(f, axis_name='i')(data)
>>> print(out)
[28 29 30 31] # on host 0
[32 33 34 35] # on host 1
Each host passes in a different length-4 array, corresponding to its 4 local devices, and the psum operates over all 8 values. Conceptually, the two length-4 arrays can be thought of as a sharded length-8 array (in this example equivalent to np.arange(8)) that is mapped over, with the length-8 mapped axis given name ‘i’. The pmap call on each host then returns the corresponding length-4 output shard.
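The sharded view described above can be mimicked on one machine without any collectives. This sketch is purely illustrative and uses plain NumPy rather than pmap: it treats each "host" as a length-4 slice of a length-8 array and computes what the psum over axis 'i' contributes to every element.

```python
import numpy as np

num_hosts, local_devices = 2, 4  # the two-host, 4-device scenario above
global_data = np.arange(num_hosts * local_devices)  # conceptually np.arange(8)

# Each host only "sees" its own contiguous shard of the global array.
shards = [global_data[h * local_devices:(h + 1) * local_devices] for h in range(num_hosts)]

# psum over axis 'i' reduces across all 8 devices, i.e. across every shard.
total = sum(int(s.sum()) for s in shards)

# f(x) = x + psum(x), evaluated shard by shard as each host would see it.
outputs = [s + total for s in shards]
for h, out in enumerate(outputs):
    print(f"host {h}: {out}")
# host 0: [28 29 30 31]
# host 1: [32 33 34 35]
```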
The devices argument can be used to specify exactly which devices are used to run the parallel computation. For example, again assuming a single host with 8 devices, the following code defines two parallel computations, one which runs on the first six devices and one on the remaining two:

>>> from functools import partial
>>> @partial(pmap, axis_name='i', devices=jax.devices()[:6])
... def f1(x):
...   return x / jax.lax.psum(x, axis_name='i')
>>> @partial(pmap, axis_name='i', devices=jax.devices()[6:])
... def f2(x):
...   return jax.lax.psum(x ** 2, axis_name='i')
>>> print(f1(np.arange(6.)))
[0. 0.06666667 0.13333333 0.2 0.26666667 0.33333333]
>>> print(f2(np.array([2., 3.])))
[ 13. 13.]
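Because the size of the mapped axis must equal the number of devices passed in devices, it can help to derive each input shape from the device list itself. This sketch assumes only that at least one device is available; the half-and-half split is an arbitrary illustrative choice, and on a single-device machine both computations fall back to that same device.

```python
import jax
import jax.numpy as jnp

all_devices = jax.devices()
k = max(1, len(all_devices) // 2)  # illustrative split point
first = all_devices[:k]
rest = all_devices[k:] or all_devices[:1]  # fall back to one device if the split leaves none

f1 = jax.pmap(lambda x: x / jax.lax.psum(x, axis_name='i'), axis_name='i', devices=first)
f2 = jax.pmap(lambda x: jax.lax.psum(x ** 2, axis_name='i'), axis_name='i', devices=rest)

out1 = f1(jnp.arange(1.0, len(first) + 1.0))  # mapped axis size == len(first)
out2 = f2(jnp.full((len(rest),), 2.0))        # mapped axis size == len(rest)
print(out1.sum(), out2)  # out1 sums to about 1.0; every entry of out2 is 4 * len(rest)
```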