# jax package

## Just-in-time compilation (jit)

jax.jit(fun, static_argnums=(), device=None, backend=None)

Sets up fun for just-in-time compilation with XLA.

Parameters:

- fun – Function to be jitted. Should be a pure function, as side-effects may only be executed once. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_argnums can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.
- static_argnums – A tuple of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded. Calling the jitted function with different values for these constants will trigger recompilation. If the jitted function is called with fewer positional arguments than indicated by static_argnums, an error is raised. Defaults to ().
- device – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].
- backend – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: ‘cpu’, ‘gpu’, or ‘tpu’.

Returns: A wrapped version of fun, set up for just-in-time compilation.

In the following example, selu can be compiled into a single fused kernel by XLA:

>>> @jax.jit
... def selu(x, alpha=1.67, lmbda=1.05):
...   return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
...
>>> key = jax.random.PRNGKey(0)
>>> x = jax.random.normal(key, (10,))
>>> print(selu(x))
[-0.54485154  0.27744263 -0.29255125 -0.91421586 -0.62452525 -0.2474813
 -0.8574326  -0.7823267   0.7682731   0.59566754]
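The static_argnums parameter is easiest to see with an argument that drives Python control flow. This is a minimal sketch; the function power is hypothetical and exists only for illustration. Because n is marked static, the Python loop is unrolled while tracing, and calling with a new value of n is expected to trigger a fresh compilation.

```python
import functools

import jax
import jax.numpy as jnp


# Argument 1 (`n`) is static: it must be hashable, and each distinct value
# gets its own compiled version of the function.
@functools.partial(jax.jit, static_argnums=(1,))
def power(x, n):
    result = jnp.ones_like(x)
    for _ in range(n):  # plain Python loop over a concrete int, unrolled while tracing
        result = result * x
    return result


x = jnp.arange(4.0)
cubes = power(x, 3)    # compiled for n=3
squares = power(x, 2)  # different static value, so this compiles again
print(cubes, squares)
```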

jax.disable_jit()

Context manager that disables jit behavior under its dynamic context.

For debugging purposes, it is useful to have a mechanism that disables jit everywhere in a dynamic context.

Values that have a data dependence on the arguments to a jitted function are traced and abstracted. For example, an abstract value may be a ShapedArray instance, representing the set of all possible arrays with a given shape and dtype, but not representing one concrete array with specific values. You might notice those if you use a benign side-effecting operation in a jitted function, like a print:

>>> @jax.jit
... def f(x):
...   y = x * 2
...   print("Value of y is", y)
...   return y + 3
...
>>> print(f(jax.numpy.array([1, 2, 3])))
Value of y is Traced<ShapedArray(int32[3]):JaxprTrace(level=-1/1)>
[5 7 9]


Here y has been abstracted by jit to a ShapedArray, which represents an array with a fixed shape and type but an arbitrary value. It’s also traced. If we want to see a concrete value while debugging, and avoid the tracer too, we can use the disable_jit context manager:

>>> with jax.disable_jit():
...   print(f(jax.numpy.array([1, 2, 3])))
...
Value of y is [2 4 6]
[5 7 9]
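Seeing concrete values also means value-dependent Python control flow works under disable_jit, while the same code fails when traced by jit. A small sketch, with relu_branchy as a hypothetical example function; the exact exception class raised under jit differs between JAX versions, so it is caught generically.

```python
import jax
import jax.numpy as jnp


@jax.jit
def relu_branchy(x):
    # Python-level branching on the *value* of x: fine for concrete arrays,
    # but not for the abstract tracers that jit passes in.
    if x > 0:
        return x
    return jnp.zeros_like(x)


x = jnp.float32(2.0)

with jax.disable_jit():
    out = relu_branchy(x)  # runs eagerly, so `x > 0` has a concrete value

try:
    relu_branchy(x)        # traced under jit: branching on a tracer fails
    jit_failed = False
except Exception:          # exception type is version-dependent
    jit_failed = True

print(out, jit_failed)
```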

jax.xla_computation(fun, static_argnums=(), axis_env=None, backend=None, tuple_args=False, instantiate_const_outputs=True)

Creates a function that produces its XLA computation given example args.

Parameters:

- fun – Function from which to form XLA computations.
- static_argnums – See the jax.jit docstring.
- axis_env – Optional, a list of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap. See the examples below.
- backend – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: ‘cpu’, ‘gpu’, or ‘tpu’.
- tuple_args – Optional bool, defaults to False. If True, the resulting XLA computation will have a single tuple argument that is unpacked into the specified function arguments.
- instantiate_const_outputs – Optional bool, defaults to True. If False, then xla_computation does not instantiate constant-valued outputs in the XLA computation, and so the result is closer to the computation that jax.jit produces and may be more useful for studying jit behavior. If True, then constant-valued outputs are instantiated in the XLA computation, which may be more useful for staging computations out of JAX entirely.

Returns: A wrapped version of fun that, when applied to example arguments, returns a built XLA Computation (see xla_client.py), from which representations of the unoptimized XLA HLO computation can be extracted using methods like GetHloText, GetSerializedProto, and GetHloDotGraph.

For example:

>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> c = jax.xla_computation(f)(3.)
>>> print(c.GetHloText())
HloModule jaxpr_computation__4.5
ENTRY jaxpr_computation__4.5 {
  tuple.1 = () tuple()
  parameter.2 = f32[] parameter(0)
  cosine.3 = f32[] cosine(parameter.2)
  ROOT sine.4 = f32[] sine(cosine.3)
}


Here’s an example that involves a parallel collective and axis name:

>>> def f(x): return x - jax.lax.psum(x, 'i')
>>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2)
>>> print(c.GetHloText())
HloModule jaxpr_computation.9
primitive_computation.3 {
  parameter.4 = s32[] parameter(0)
  parameter.5 = s32[] parameter(1)
  ROOT add.6 = s32[] add(parameter.4, parameter.5)
}
ENTRY jaxpr_computation.9 {
  tuple.1 = () tuple()
  parameter.2 = s32[] parameter(0)
  all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3
  ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7)
}


Notice the replica_groups that were generated. Here’s an example that generates more interesting replica_groups:

>>> def g(x):
...   rowsum = jax.lax.psum(x, 'i')
...   colsum = jax.lax.psum(x, 'j')
...   allsum = jax.lax.psum(x, ('i', 'j'))
...   return rowsum, colsum, allsum
...
>>> axis_env = [('i', 4), ('j', 2)]
>>> c = jax.xla_computation(g, axis_env=axis_env)(5.)
>>> print(c.GetHloText())
HloModule jaxpr_computation__1.19
[removed uninteresting text here]
ENTRY jaxpr_computation__1.19 {
  tuple.1 = () tuple()
  parameter.2 = f32[] parameter(0)
  all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3
  all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8
  all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
  ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
}
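The replica_groups in this output are consistent with numbering the 4 x 2 = 8 replicas in row-major order over axis_env, so that replica (i, j) gets id 2*i + j; a psum over some axes then groups together the replicas that agree on all remaining axes. The pure-Python sketch below reproduces the three groupings under that row-major assumption. The helper replica_groups is hypothetical and not part of JAX.

```python
import itertools


def replica_groups(axis_env, reduce_axes):
    """Group replica ids that differ only along the axes in `reduce_axes`.

    Hypothetical helper. Assumes replicas are numbered in row-major order
    over `axis_env` (first axis varies slowest).
    """
    names = [name for name, _ in axis_env]
    sizes = [size for _, size in axis_env]
    groups = {}
    for coords in itertools.product(*(range(size) for size in sizes)):
        # Row-major linear replica id, e.g. (i, j) -> i * 2 + j for [('i', 4), ('j', 2)].
        replica_id = 0
        for coord, size in zip(coords, sizes):
            replica_id = replica_id * size + coord
        # Replicas in one group share their coordinates on every non-reduced axis.
        key = tuple(c for name, c in zip(names, coords) if name not in reduce_axes)
        groups.setdefault(key, []).append(replica_id)
    return sorted(groups.values())


env = [('i', 4), ('j', 2)]
print(replica_groups(env, ('i',)))      # [[0, 2, 4, 6], [1, 3, 5, 7]]
print(replica_groups(env, ('j',)))      # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(replica_groups(env, ('i', 'j')))  # [[0, 1, 2, 3, 4, 5, 6, 7]]
```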

jax.make_jaxpr(fun)

Creates a function that produces its jaxpr given example args.

Parameters:

- fun – The function whose jaxpr is to be computed. Its positional arguments and return value should be arrays, scalars, or standard Python containers (tuple/list/dict) thereof.

Returns: A wrapped version of fun that, when applied to example arguments, returns a TypedJaxpr representation of fun on those arguments.

A jaxpr is JAX’s intermediate representation for program traces. The jaxpr language is based on the simply-typed first-order lambda calculus with let-bindings. make_jaxpr adapts a function to return its jaxpr, which we can inspect to understand what JAX is doing internally.

The jaxpr returned is a trace of fun abstracted to ShapedArray level. Other levels of abstraction exist internally.

We do not describe the semantics of the jaxpr language in detail here, but instead give a few examples.

>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> print(f(3.0))
-0.83602184
>>> jax.make_jaxpr(f)(3.0)
{ lambda  ;  ; a.
  let b = cos a
      c = sin b
  in [c] }
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda  ;  ; a.
  let b = cos a
      c = cos b
      d = mul 1.0 c
      e = neg d
      f = sin a
      g = mul e f
  in [g] }
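Besides printing it, the returned object can be inspected programmatically. This is a small sketch that assumes the equations are reachable as .jaxpr.eqns, each with a primitive attribute (true of the TypedJaxpr described here and, as far as we know, of the ClosedJaxpr returned by later JAX versions). Because the printed jaxpr format changes between versions, the sketch only looks at primitive names.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(jnp.cos(x))


typed_jaxpr = jax.make_jaxpr(f)(3.0)
# Each equation records one primitive application, in program order.
prims = [eqn.primitive.name for eqn in typed_jaxpr.jaxpr.eqns]
print(prims)
```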

jax.eval_shape(fun, *args, **kwargs)

Compute the shape/dtype of fun(*args, **kwargs) without any FLOPs.

This utility function is useful for performing shape inference. Its input/output behavior is defined by:

import jax

def eval_shape(fun, *args, **kwargs):
  out = fun(*args, **kwargs)
  return jax.tree_util.tree_map(shape_dtype_struct, out)

def shape_dtype_struct(x):
  return ShapeDtypeStruct(x.shape, x.dtype)

class ShapeDtypeStruct(object):
  __slots__ = ["shape", "dtype"]
  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = dtype


In particular, the output is a pytree of objects that have shape and dtype attributes, but nothing else about them is guaranteed by the API.

But instead of applying fun directly, which might be expensive, it uses JAX’s abstract interpretation machinery to evaluate the shapes without doing any FLOPs.

Using eval_shape can also catch shape errors, and will raise the same shape errors as evaluating fun(*args, **kwargs).

Parameters:

- *args – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types. Since only the shape and dtype attributes are accessed, only values that duck-type arrays are required, rather than real ndarrays. The duck-typed objects cannot be namedtuples because those are treated as standard Python containers. See the example below.
- **kwargs – a keyword argument dict of arrays, scalars, or (nested) standard Python containers (pytrees) of those types. As in args, array values need only be duck-typed to have shape and dtype attributes.

For example:

>>> f = lambda A, x: np.tanh(np.dot(A, x))
>>> class MyArgArray(object):
...   def __init__(self, shape, dtype):
...     self.shape = shape
...     self.dtype = dtype
...
>>> A = MyArgArray((2000, 3000), np.float32)
>>> x = MyArgArray((3000, 1000), np.float32)
>>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
>>> print(out.shape)
(2000, 1000)
>>> print(out.dtype)
float32
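Instead of a hand-written duck-typed class, jax.ShapeDtypeStruct can serve as the abstract argument; this sketch assumes a JAX version that exports it at the top level. It also exercises the shape-error behavior described above: a mismatched contraction dimension is reported without allocating or multiplying the full-size matrices. The exact exception class is version-dependent, so two plausible ones are caught.

```python
import jax
import jax.numpy as jnp


def f(A, x):
    return jnp.tanh(jnp.dot(A, x))


A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
x_ok = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
x_bad = jax.ShapeDtypeStruct((2999, 1000), jnp.float32)  # inner dims disagree

out = jax.eval_shape(f, A, x_ok)  # abstract evaluation only, no FLOPs
print(out.shape, out.dtype)

try:
    jax.eval_shape(f, A, x_bad)
    caught = False
except (TypeError, ValueError):
    caught = True
print("shape error caught:", caught)
```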


## Automatic differentiation

jax.grad(fun, argnums=0, has_aux=False, holomorphic=False)

Creates a function which evaluates the gradient of fun.

Parameters:

- fun – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.).
- argnums – Optional, integer or tuple of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
- has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
- holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Returns: A function with the same arguments as fun that evaluates the gradient of fun. If argnums is an integer, the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True, a pair of (gradient, auxiliary_data) is returned.

For example:

>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.961043
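A sketch combining a tuple argnums with has_aux, as described in the parameter list. The function loss and its data are hypothetical; by hand, pred = [2.5, 4.5], so the residuals are [1.5, 3.5], giving d/dw = 2 * (1.5 * 1 + 3.5 * 2) = 17.0 and d/db = 2 * (1.5 + 3.5) = 10.0.

```python
import jax
import jax.numpy as jnp


def loss(w, b, x):
    pred = w * x + b
    err = jnp.sum((pred - 1.0) ** 2)
    return err, pred  # (scalar to differentiate, auxiliary output)


x = jnp.array([1.0, 2.0])
# argnums=(0, 1) differentiates w.r.t. w and b; has_aux=True passes `pred` through untouched.
(dw, db), pred = jax.grad(loss, argnums=(0, 1), has_aux=True)(2.0, 0.5, x)
print(dw, db, pred)
```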

jax.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False)

Creates a function which evaluates both fun and the gradient of fun.

Parameters:

- fun – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.).
- argnums – Optional, integer or tuple of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
- has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
- holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Returns: A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer, the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments.

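A minimal sketch: a single call returns the value together with its gradient, rather than calling fun and grad(fun) separately. The function f is a hypothetical example; for x = [1, 2, 3], the value is 14 and the gradient 2x is [2, 4, 6].

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(x ** 2)


x = jnp.array([1.0, 2.0, 3.0])
value, g = jax.value_and_grad(f)(x)  # one call returns f(x) and grad f(x)
print(value, g)
```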
jax.jacfwd(fun, argnums=0, holomorphic=False)

Jacobian of fun evaluated column-by-column using forward-mode AD.

Parameters:

- fun – Function whose Jacobian is to be computed.
- argnums – Optional, integer or tuple of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
- holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Returns: A function with the same arguments as fun that evaluates the Jacobian of fun using forward-mode automatic differentiation.

>>> def f(x):
...   return jax.numpy.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jax.numpy.sin(x[0])])
...
>>> print(jax.jacfwd(f)(jax.numpy.array([1., 2., 3.])))
[[ 1.          0.          0.        ]
 [ 0.          0.          5.        ]
 [ 0.         16.         -2.        ]
 [ 1.6209068   0.          0.84147096]]
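"Column-by-column" can be made concrete: column j of the Jacobian is the JVP of f along the j-th input basis vector. This sketch rebuilds the Jacobian for the example f that way with a Python loop over jax.jvp and compares it with jax.jacfwd. It illustrates the idea only; jacfwd itself vectorizes these JVPs with vmap rather than looping.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.asarray([x[0], 5 * x[2], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])


x = jnp.array([1.0, 2.0, 3.0])

# Column j of the Jacobian is the JVP of f along the j-th standard basis vector.
basis = jnp.eye(x.shape[0])
cols = [jax.jvp(f, (x,), (basis[j],))[1] for j in range(x.shape[0])]
J_manual = jnp.stack(cols, axis=1)  # shape (4, 3): outputs x inputs

J_fwd = jax.jacfwd(f)(x)
print(jnp.allclose(J_manual, J_fwd))
```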

jax.jacrev(fun, argnums=0, holomorphic=False)

Jacobian of fun evaluated row-by-row using reverse-mode AD.

Parameters:

- fun – Function whose Jacobian is to be computed.
- argnums – Optional, integer or tuple of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
- holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Returns: A function with the same arguments as fun that evaluates the Jacobian of fun using reverse-mode automatic differentiation.

>>> def f(x):
...   return jax.numpy.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jax.numpy.sin(x[0])])
...
>>> print(jax.jacrev(f)(jax.numpy.array([1., 2., 3.])))
[[ 1.          0.          0.        ]
 [ 0.          0.          5.        ]
 [ 0.         16.         -2.        ]
 [ 1.6209068   0.          0.84147096]]
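Symmetrically, "row-by-row" means row i of the Jacobian is the VJP of f with the i-th output basis cotangent. The sketch below builds the rows with jax.vjp and checks them against both jacrev and jacfwd. Since jacfwd needs one JVP per input and jacrev one VJP per output, jacfwd is generally the cheaper choice for tall Jacobians (more outputs than inputs) and jacrev for wide ones.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.asarray([x[0], 5 * x[2], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])


x = jnp.array([1.0, 2.0, 3.0])
y, f_vjp = jax.vjp(f, x)

# Row i of the Jacobian is the VJP of f with the i-th standard basis cotangent.
out_basis = jnp.eye(y.shape[0])
rows = [f_vjp(out_basis[i])[0] for i in range(y.shape[0])]
J_manual = jnp.stack(rows, axis=0)  # shape (4, 3)

J_rev = jax.jacrev(f)(x)
J_fwd = jax.jacfwd(f)(x)
print(jnp.allclose(J_manual, J_rev), jnp.allclose(J_rev, J_fwd))
```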

jax.hessian(fun, argnums=0, holomorphic=False)

Hessian of fun.

Parameters:

- fun – Function whose Hessian is to be computed.
- argnums – Optional, integer or tuple of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
- holomorphic – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Returns: A function with the same arguments as fun that evaluates the Hessian of fun.

>>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
>>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
[[   6.   -2.]
 [  -2. -480.]]
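JAX composes hessian as forward-over-reverse differentiation, jacfwd(jacrev(fun)). The sketch checks that composition, and the analytic Hessian of the example, [[6 x0, -2], [-2, -30 x1^4]], against jax.hessian at x = (1, 2), where -30 * 2^4 = -480.

```python
import jax
import jax.numpy as jnp

g = lambda x: x[0] ** 3 - 2 * x[0] * x[1] - x[1] ** 6
x = jnp.array([1.0, 2.0])

H = jax.hessian(g)(x)
H_composed = jax.jacfwd(jax.jacrev(g))(x)  # forward-over-reverse
H_analytic = jnp.array([[6 * x[0], -2.0],
                        [-2.0, -30 * x[1] ** 4]])
print(H)
```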

jax.jvp(fun, primals, tangents)

Computes a (forward-mode) Jacobian-vector product of fun.

Parameters:

- fun – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
- primals – The primal values at which the Jacobian of fun should be evaluated. Should be either a tuple or a list of arguments, and its length should equal the number of positional parameters of fun.
- tangents – The tangent vector for which the Jacobian-vector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as primals.

Returns: A (primals_out, tangents_out) pair, where primals_out is fun(*primals), and tangents_out is the Jacobian-vector product of fun evaluated at primals with tangents. The tangents_out value has the same Python tree structure and shapes as primals_out.

For example:

>>> y, v = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
>>> print(y)
0.09983342
>>> print(v)
0.19900084
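With several positional arguments, the tangent tuple selects a direction in the joint input space, and the JVP is the directional derivative along it. A sketch with a hypothetical two-argument f(x, y) = x * sin(y): the tangent (1, 0) gives df/dx = sin(y), the tangent (0, 1) gives df/dy = x * cos(y), and a general tangent (a, b) gives their weighted sum.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    return x * jnp.sin(y)


primals = (2.0, 0.5)

_, t_x = jax.jvp(f, primals, (1.0, 0.0))    # perturb x only: sin(y)
_, t_y = jax.jvp(f, primals, (0.0, 1.0))    # perturb y only: x * cos(y)
y, t = jax.jvp(f, primals, (0.3, -0.2))     # linear combination of the two
print(y, t_x, t_y, t)
```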

jax.linearize(fun, *primals)

Produce a linear approximation to fun using jvp and partial evaluation.

Parameters:

- fun – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
- primals – The primal values at which the Jacobian of fun should be evaluated. Should be a tuple of arrays, scalars, or standard Python containers thereof. The length of the tuple is equal to the number of positional parameters of fun.

Returns: A pair where the first element is the value of fun(*primals) and the second element is a function that evaluates the (forward-mode) Jacobian-vector product of fun evaluated at primals without redoing the linearization work.

In terms of values computed, linearize behaves much like a curried jvp, where these two code blocks compute the same values:

y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

y, f_jvp = jax.linearize(f, x)
out_tangent = f_jvp(in_tangent)


However, the difference is that linearize uses partial evaluation so that the function f is not re-linearized on calls to f_jvp. In general that means the memory usage scales with the size of the computation, much like in reverse-mode. (Indeed, linearize has a similar signature to vjp!)

This function is mainly useful if you want to apply f_jvp multiple times, i.e. to evaluate a pushforward for many different input tangent vectors at the same linearization point. Moreover, if all the input tangent vectors are known at once, it can be more efficient to vectorize using vmap, as in:

pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))


By using vmap and jvp together like this we avoid the stored-linearization memory cost that scales with the depth of the computation, which is incurred by both linearize and vjp.
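A runnable version of the two strategies, assuming f(x) = 3 sin(x) + cos(x / 2) at x = 2 and two input tangents, 3 and 4, all chosen for illustration. Both should yield the same output tangents; they differ in what is stored, not in what is computed.

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(x):
    return 3.0 * jnp.sin(x) + jnp.cos(x / 2.0)


x = 2.0
in_tangents = jnp.array([3.0, 4.0])  # several tangents at the same point

# Strategy 1: linearize once, then apply the stored linear map to each tangent.
y, f_jvp = jax.linearize(f, x)
lin_out = jnp.stack([f_jvp(t) for t in in_tangents])

# Strategy 2: vmap a jvp over the batch of tangents; the primal output is
# unbatched, hence out_axes=(None, 0).
pushfwd = partial(jax.jvp, f, (x,))
y2, vmap_out = jax.vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
print(lin_out, vmap_out)
```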

Here’s a more complete example of using linearize:

>>> def f(x): return 3. * np.sin(x) + np.cos(x / 2.)
...
>>> jax.jvp(f, (2.,), (3.,))
(array(3.2681944, dtype=float32), array(-5.007528, dtype=float32))
>>> y, f_jvp = jax.linearize(f, 2.)
>>> print(y)
3.2681944
>>> print(f_jvp(3.))
-5.007528
>>> print(f_jvp(4.))
-6.676704

jax.vjp(fun, *primals, **kwargs)

Compute a (reverse-mode) vector-Jacobian product of fun.

grad is implemented as a special case of vjp.

Parameters:

- fun – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
- primals – A sequence of primal values at which the Jacobian of fun should be evaluated. The length of primals should be equal to the number of positional parameters to fun. Each primal value should be an array, scalar, or standard Python container thereof.
- has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

Returns: A (primals_out, vjpfun) pair, where primals_out is fun(*primals). vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same shapes as primals, representing the vector-Jacobian product of fun evaluated at primals.

>>> def f(x, y):
...   return jax.numpy.sin(x), jax.numpy.cos(y)
...
>>> primals, f_vjp = jax.vjp(f, 0.5, 1.0)
>>> xbar, ybar = f_vjp((-0.7, 0.3))
>>> print(xbar)
-0.61430776
>>> print(ybar)
-0.2524413
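Since grad is a special case of vjp, the relationship can be checked directly: for a scalar-valued function, pulling a cotangent of 1.0 back through the VJP gives the gradient. The function f is a hypothetical example.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(jnp.tanh(x) ** 2)


x = jnp.array([0.1, -0.4, 2.0])
y, f_vjp = jax.vjp(f, x)
# Seeding the VJP with cotangent 1.0 (scalar output) yields the gradient.
(g_vjp,) = f_vjp(jnp.ones_like(y))
g_grad = jax.grad(f)(x)
print(g_vjp, g_grad)
```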

jax.custom_transforms(fun)

Wraps a function so that its transformation behavior can be controlled.

A primary use case of custom_transforms is defining custom VJP rules (aka custom gradients) for a Python function, while still supporting other transformations like jax.jit and jax.vmap. Custom differentiation rules can be supplied using the jax.defjvp and jax.defvjp functions.

The custom_transforms decorator wraps fun so that its transformation behavior can be overridden, but not all transformation rules need to be specified manually. The default behavior is retained for any non-overridden rules.

The function fun must satisfy the same constraints required for jit compilation. In particular, the shapes of arrays in the computation of fun may depend on the shapes of fun’s arguments, but not their values. Value-dependent Python control flow is also not yet supported.

Parameters:

- fun – a Python callable. Must be functionally pure. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof.

Returns: A Python callable with the same input/output and transformation behavior as fun, but for which custom transformation rules can be supplied, e.g. using jax.defvjp.

For example:

>>> @jax.custom_transforms
... def f(x):
...   return np.sin(x ** 2)
...
>>> print(f(3.))
0.4121185
>>> print(jax.grad(f)(3.))
-5.4667816
>>> jax.defvjp(f, lambda g, ans, x: g * x)
>>> print(jax.grad(f)(3.))
3.0

jax.defjvp(fun, *jvprules)

Define JVP rules for each argument separately.

This function is a convenience wrapper around jax.defjvp_all for separately defining JVP rules for each of the function’s arguments. This convenience wrapper does not provide a mechanism for depending on anything other than the function arguments and its primal output value, though depending on intermediate results is possible using jax.defjvp_all.

The signature of each component JVP rule is lambda g, ans, *primals: ... where g represents the tangent of the corresponding positional argument, ans represents the output primal, and *primals represents all the primal positional arguments.

Defining a custom JVP rule also affects the default VJP rule, which is derived from the JVP rule automatically via transposition.

Parameters:

- fun – a custom_transforms function.
- *jvprules – a sequence of functions or Nones specifying the JVP rule for each corresponding positional argument. When an element is None, it indicates that the Jacobian from the corresponding input to the output is zero.

Returns: None. A side-effect is that fun is associated with the JVP rule specified by *jvprules.

For example:

>>> @jax.custom_transforms
... def f(x):
...   return np.sin(x ** 2)
...
>>> print(f(3.))
0.4121185
>>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
>>> print(out_primal)
0.4121185
>>> print(out_tangent)
-10.933563
>>> jax.defjvp(f, lambda g, ans, x: 8. * g + ans)
>>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
>>> print(out_primal)
0.4121185
>>> print(out_tangent)
16.412119
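Both out_tangent values can be reproduced by hand. This sketch uses plain NumPy rather than JAX, since custom_transforms is not available in every JAX version: the default tangent follows the chain rule, cos(x^2) * 2x * t, and the custom one applies the registered rule 8. * g + ans with g = t. The results should agree with the float32 output above to about six significant digits.

```python
import numpy as np

x, t = 3.0, 2.0
ans = np.sin(x ** 2)  # primal output, about 0.4121185

# Default JVP via the chain rule: d/dx sin(x**2) = cos(x**2) * 2x, times tangent t.
default_tangent = np.cos(x ** 2) * 2.0 * x * t

# Custom rule from the example, lambda g, ans, x: 8. * g + ans, with g = t.
custom_tangent = 8.0 * t + ans

print(default_tangent, custom_tangent)
```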

jax.defjvp_all(fun, custom_jvp)

Define a custom JVP rule for a custom_transforms function.

If fun represents a function with signature a -> b, then custom_jvp represents a function with signature (a, T a) -> (b, T b), where we use T x to represent a tangent type for the type x.

In more detail, custom_jvp must take two arguments, both tuples of length equal to the number of positional arguments to fun. The first argument to custom_jvp represents the input primal values, and the second represents the input tangent values. custom_jvp must return a pair where the first element represents the output primal value and the second element represents the output tangent value.

Defining a custom JVP rule also affects the default VJP rule, which is derived from the JVP rule automatically via transposition.

Parameters:

- fun – a custom_transforms function.
- custom_jvp – a Python callable specifying the JVP rule, taking two tuples as arguments specifying the input primal values and tangent values, respectively. The tuple elements can be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. The output must be a pair representing the primal output and tangent output, which can be arrays, scalars, or (nested) standard Python containers. Must be functionally pure.

Returns: None. A side-effect is that fun is associated with the JVP rule specified by custom_jvp.

For example:

>>> @jax.custom_transforms
... def f(x):
...   return np.sin(x ** 2)
...
>>> print(f(3.))
0.4121185
>>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
>>> print(out_primal)
0.4121185
>>> print(out_tangent)
-10.933563
>>> jax.defjvp_all(f, lambda ps, ts: (np.sin(ps[0] ** 2), 8. * ts[0]))
>>> out_primal, out_tangent = jax.jvp(f, (3.,), (2.,))
>>> print(out_primal)
0.4121185
>>> print(out_tangent)
16.0
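Later JAX releases replaced custom_transforms and jax.defjvp_all with the jax.custom_jvp decorator. Assuming such a release, the example above can be sketched as follows. The rule receives the same (primals, tangents) tuples and returns a (primal_out, tangent_out) pair, so it reproduces the custom tangent 16.0 for an input tangent of 2.0.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def f(x):
    return jnp.sin(x ** 2)


@f.defjvp
def f_jvp(primals, tangents):
    # Tuples of primals and tangents in, (primal_out, tangent_out) out.
    (x,), (t,) = primals, tangents
    return jnp.sin(x ** 2), 8.0 * t


out_primal, out_tangent = jax.jvp(f, (3.0,), (2.0,))
print(out_primal, out_tangent)
```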

jax.defvjp(fun, *vjprules)

Define VJP rules for each argument separately.

This function is a convenience wrapper around jax.defvjp_all for separately defining VJP rules for each of the function’s arguments. This convenience wrapper does not provide a mechanism for depending on anything other than the function arguments and its primal output value, though depending on intermediate results is possible using jax.defvjp_all.

The signature of each component VJP rule is lambda g, ans, *primals: ... where g represents the output cotangent, ans represents the output primal, and *primals represents all the primal positional arguments.

Parameters:

- fun – a custom_transforms function.
- *vjprules – a sequence of functions or Nones specifying the VJP rule for each corresponding positional argument. When an element is None, it indicates that the Jacobian from the corresponding input to the output is zero.

Returns: None. A side-effect is that fun is associated with the VJP rule specified by *vjprules.

For example:

>>> @jax.custom_transforms
... def f(x, y):
...   return np.sin(x ** 2 + y)
...
>>> print(f(3., 4.))
0.42016703
>>> print(jax.grad(f)(3., 4.))
5.4446807
>>> print(jax.grad(f, 1)(3., 4.))
0.9074468
>>> jax.defvjp(f, None, lambda g, ans, x, y: g + x + y + ans)
>>> print(jax.grad(f)(3., 4.))
0.0
>>> print(jax.grad(f, 1)(3., 4.))
8.420167

jax.defvjp_all(fun, custom_vjp)

Define a custom VJP rule for a custom_transforms function.

If fun represents a function with signature a -> b, then custom_vjp represents a function with signature a -> (b, CT b -> CT a) where we use CT x to represent a cotangent type for the type x. That is, custom_vjp should take the same arguments as fun and return a pair where the first element represents the primal value of fun applied to the arguments, and the second element is a VJP function that maps from output cotangents to input cotangents, returning a tuple with length equal to the number of positional arguments supplied to fun.

The VJP function returned as the second element of the output of custom_vjp can close over intermediate values computed when evaluating the primal value of fun. That is, use lexical closure to share work between the forward pass and the backward pass of reverse-mode automatic differentiation.

See also jax.custom_gradient.

Parameters:

- fun – a custom_transforms function.
- custom_vjp – a Python callable specifying the VJP rule, taking the same arguments as fun and returning a pair where the first element is the value of fun applied to the arguments and the second element is a Python callable representing the VJP map from output cotangents to input cotangents. The returned VJP function must accept a value with the same shape as the value of fun applied to the arguments and must return a tuple with length equal to the number of positional arguments to fun. Arguments can be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Must be functionally pure.

Returns: None. A side-effect is that fun is associated with the VJP rule specified by custom_vjp.

For example:

>>> @jax.custom_transforms
... def f(x):
...   return np.sin(x ** 2)
...
>>> print(f(3.))
0.4121185
>>> print(jax.grad(f)(3.))
-5.4667816
>>> jax.defvjp_all(f, lambda x: (np.sin(x ** 2), lambda g: (g * x,)))
>>> print(f(3.))
0.4121185
>>> print(jax.grad(f)(3.))
3.0


An example with a function on two arguments, so that the VJP function must return a tuple of length two:

>>> @jax.custom_transforms
... def f(x, y):
...   return x * y
...
>>> jax.defvjp_all(f, lambda x, y: (x * y, lambda g: (g * y, g * x)))
>>> print(f(3., 4.))
12.0
>>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
(4.0, 3.0)
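Assuming a later JAX release that provides jax.custom_vjp (which superseded custom_transforms and defvjp_all), the two-argument example can be sketched with separate forward and backward functions. The forward function returns residuals for the backward pass, playing the role of the lexical closure described above, and the backward function scales each cotangent by g so the rule stays correct for cotangents other than 1.0.

```python
import jax


@jax.custom_vjp
def f(x, y):
    return x * y


def f_fwd(x, y):
    # Forward pass: primal output plus residuals saved for the backward pass.
    return x * y, (x, y)


def f_bwd(res, g):
    # Backward pass: one cotangent per positional argument, scaled by g.
    x, y = res
    return (g * y, g * x)


f.defvjp(f_fwd, f_bwd)

grads = jax.grad(f, argnums=(0, 1))(3.0, 4.0)
_, f_vjp = jax.vjp(f, 3.0, 4.0)
scaled = f_vjp(2.0)  # cotangent 2.0 scales both input cotangents
print(f(3.0, 4.0), grads, scaled)
```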

jax.custom_gradient(fun)

Convenience function for defining custom VJP rules (aka custom gradients).

While the canonical way to define custom VJP rules is via jax.defvjp_all and its convenience wrappers, the custom_gradient convenience wrapper follows TensorFlow’s tf.custom_gradient API. The difference here is that custom_gradient can be used as a decorator on one function that returns both the primal value (representing the output of the mathematical function to be differentiated) and the VJP (gradient) function.

If the mathematical function to be differentiated has type signature a -> b, then the Python callable fun should have signature a -> (b, CT b -> CT a) where we use CT x to denote a cotangent type for x. See the example below. That is, fun should return a pair where the first element represents the value of the mathematical function to be differentiated and the second element is a function that represents the custom VJP rule.

The custom VJP function returned as the second element of the output of fun can close over intermediate values computed when evaluating the function to be differentiated. That is, use lexical closure to share work between the forward pass and the backward pass of reverse-mode automatic differentiation.

Parameters:

- fun – a Python callable specifying both the mathematical function to be differentiated and its reverse-mode differentiation rule. It should return a pair consisting of an output value and a Python callable that represents the custom gradient function.

Returns: A Python callable with signature a -> b, i.e. that returns the output value specified by the first element of fun’s output pair. A side effect is that, under the hood, jax.defvjp_all is called to set up the returned Python callable with the custom VJP rule specified by the second element of fun’s output pair.

For example:

>>> @jax.custom_gradient
... def f(x):
...   return x ** 2, lambda g: (g * x,)
...
>>> print(f(3.))
9.0
>>> print(jax.grad(f)(3.))
3.0


An example with a function on two arguments, so that the VJP function must return a tuple of length two:

>>> @jax.custom_gradient
... def f(x, y):
...   return x * y, lambda g: (g * y, g * x)
...
>>> print(f(3., 4.))
12.0
>>> print(jax.grad(f, argnums=(0, 1))(3., 4.))
(4.0, 3.0)


## Vectorization (vmap)

jax.vmap(fun, in_axes=0, out_axes=0)

Vectorizing map. Creates a function which maps fun over argument axes.

Parameters:

- fun – Function to be mapped over additional axes.
- in_axes – A nonnegative integer, None, or (nested) standard Python container (tuple/list/dict) thereof specifying which input array axes to map over. If each positional argument to fun is an array, then in_axes can be a nonnegative integer, a None, or a tuple of integers and Nones with length equal to the number of positional arguments to fun. An integer or None indicates which array axis to map over for all arguments (with None indicating not to map any axis), and a tuple indicates which axis to map for each corresponding positional argument. If the positional arguments to fun are container types, the corresponding element of in_axes can itself be a matching container, so that distinct array axes can be mapped for different container elements. in_axes must be a container tree prefix of the positional argument tuple passed to fun.
- out_axes – A nonnegative integer, None, or (nested) standard Python container (tuple/list/dict) thereof indicating where the mapped axis should appear in the output.

Returns: Batched/vectorized version of fun with arguments that correspond to those of fun, but with extra array axes at positions indicated by in_axes, and a return value that corresponds to that of fun, but with extra array axes at positions indicated by out_axes.

For example, we can implement a matrix-matrix product using a vector dot product:

>>> vv = lambda x, y: np.vdot(x, y)  #  ([a], [a]) -> []
>>> mv = vmap(vv, (0, None), 0)      #  ([b,a], [a]) -> [b]      (b is the mapped axis)
>>> mm = vmap(mv, (None, 1), 1)      #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)


Here we use [a,b] to indicate an array with shape (a,b). Here are some variants:

>>> mv1 = vmap(vv, (0, 0), 0)   #  ([b,a], [b,a]) -> [b]        (b is the mapped axis)
>>> mv2 = vmap(vv, (0, 1), 0)   #  ([b,a], [a,b]) -> [b]        (b is the mapped axis)
>>> mm2 = vmap(mv2, (1, 1), 0)  #  ([b,c,a], [a,c,b]) -> [c,b]  (c is the mapped axis)
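
The variants can be checked the same way by feeding arrays of ones with distinct sizes for a, b and c, so that any mix-up of axes would show up as a wrong output shape. The sizes below are arbitrary illustrative values, and the sketch assumes a recent JAX release.

```python
import jax.numpy as jnp
from jax import vmap

vv = lambda x, y: jnp.vdot(x, y)
mv1 = vmap(vv, (0, 0), 0)   # ([b,a], [b,a]) -> [b]
mv2 = vmap(vv, (0, 1), 0)   # ([b,a], [a,b]) -> [b]
mm2 = vmap(mv2, (1, 1), 0)  # ([b,c,a], [a,c,b]) -> [c,b]

b, c, a = 2, 3, 4
u = jnp.ones((b, a))

r1 = mv1(u, u)    # row-wise dot products, shape (b,)
r2 = mv2(u, u.T)  # same products, second argument stored transposed
r3 = mm2(jnp.ones((b, c, a)), jnp.ones((a, c, b)))  # shape (c, b)
```

Every entry is a dot product of two length-a vectors of ones, so each value equals a.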


Here’s an example of using container types in in_axes to specify which axes of the container elements to map over:

>>> A, B, C, D = 2, 3, 4, 5
>>> x = np.ones((A, B))
>>> y = np.ones((B, C))
>>> z = np.ones((C, D))
>>> def foo(tree_arg):
...   x, (y, z) = tree_arg
...   return np.dot(x, np.dot(y, z))
>>> tree = (x, (y, z))
>>> print(foo(tree))
[[12. 12. 12. 12. 12.]
 [12. 12. 12. 12. 12.]]
>>> from jax import vmap
>>> K = 6  # batch size
>>> x = np.ones((K, A, B))  # batch axis in different locations
>>> y = np.ones((B, K, C))
>>> z = np.ones((C, D, K))
>>> tree = (x, (y, z))
>>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
>>> print(vfoo(tree).shape)
(6, 2, 5)
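
in_axes also accepts dicts, and because it only needs to be a tree prefix of the arguments, a single None can stand in for an entire container that should be shared across the batch. The sketch below uses a hypothetical affine function that takes a parameter dict; it is illustrative and not part of the JAX API, and it assumes a recent JAX release.

```python
import jax.numpy as jnp
from jax import vmap

def affine(params, x):
    # Hypothetical helper: params = {'w': [d_out, d_in], 'b': [d_out]}, x: [d_in]
    return params['w'] @ x + params['b']

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros(3)}
xs = jnp.arange(8.0).reshape(4, 2)  # batch of 4 inputs

# A single None is a prefix covering the whole params dict: shared by all calls.
shared = vmap(affine, in_axes=(None, 0))(params, xs)

# A dict in_axes matching params: batch only 'w' (axis 0), share 'b'.
ws = jnp.stack([params['w'] * k for k in range(1, 5)])  # shape (4, 3, 2)
per_w = vmap(affine, in_axes=({'w': 0, 'b': None}, 0))(
    {'w': ws, 'b': params['b']}, xs)
```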


## Parallelization (pmap)¶

jax.pmap(fun, axis_name=None, devices=None, backend=None, axis_size=None)

Parallel map with support for collectives.

The purpose of pmap is to express single-program multiple-data (SPMD) programs. Applying pmap to a function will compile the function with XLA (similarly to jit), then execute it in parallel on XLA devices, such as multiple GPUs or multiple TPU cores. Semantically it is comparable to vmap because both transformations map a function over array axes, but where vmap vectorizes functions by pushing the mapped axis down into primitive operations, pmap instead replicates the function and executes each replica on its own XLA device in parallel.

Another key difference with vmap is that while vmap can only express pure maps, pmap enables the use of parallel SPMD collective operations, like all-reduce sum.

The mapped axis size must be less than or equal to the number of local XLA devices available, as returned by jax.local_device_count() (unless devices is specified, see below). For nested pmap calls, the product of the mapped axis sizes must be less than or equal to the number of XLA devices.
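
A portable way to respect this limit is to size the leading axis from jax.local_device_count(). The sketch assumes a recent JAX install; on a plain CPU machine the count is usually 1, so the example still runs, just on a single device.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()             # number of devices on this host
xs = jnp.arange(n * 3.0).reshape(n, 3)   # leading (mapped) axis size n

# Each device receives one row of shape (3,) and squares it.
# A leading axis larger than n would exceed the available devices.
squared = jax.pmap(lambda x: x ** 2)(xs)
```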

Multi-host platforms: On multi-host platforms such as TPU pods, pmap is designed to be used in SPMD Python programs, where every host runs the same Python code such that all hosts run the same pmapped functions in the same order. Each host should still call the pmapped function with mapped axis size equal to the number of local devices (unless devices is specified, see below), and an array of the same leading axis size will be returned as usual. However, any collective operations in fun will be computed over all participating devices, including those on other hosts, via device-to-device communication. Conceptually, this can be thought of as running a pmap over a single array sharded across hosts, where each host “sees” only its local shard of the input and output. The SPMD model requires that the same multi-host pmaps be run in the same order on all hosts, but they can be interspersed with arbitrary operations running on a single host.

Parameters:
- fun – Function to be mapped over argument axes. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
- axis_name – Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied.
- devices – This is an experimental feature and the API is likely to change. Optional, a sequence of Devices to map over. (Available devices can be retrieved via jax.devices().) If specified, the size of the mapped axis must be equal to the number of local devices in the sequence. Nested pmaps with devices specified in either the inner or outer pmap are not yet supported.
- backend – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: ‘cpu’, ‘gpu’, or ‘tpu’.

Returns: A parallelized version of fun with arguments that correspond to those of fun but each with an additional leading array axis (with equal sizes) and with output that has an additional leading array axis (with the same size).

For example, assuming 8 XLA devices are available, pmap can be used as a map along a leading array axis:

>>> import jax.numpy as np
>>> from jax import pmap
>>> out = pmap(lambda x: x ** 2)(np.arange(8))
>>> print(out)
[ 0  1  4  9 16 25 36 49]
>>> x = np.arange(3 * 2 * 2.).reshape((3, 2, 2))
>>> y = np.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
>>> out = pmap(np.dot)(x, y)
>>> print(out)
[[[    4.     9.]
  [   12.    29.]]
 [[  244.   345.]
  [  348.   493.]]
 [[ 1412.  1737.]
  [ 1740.  2141.]]]


In addition to expressing pure maps, pmap can also be used to express parallel single-program multiple-data (SPMD) programs that communicate via collective operations. For example:

>>> f = lambda x: x / jax.lax.psum(x, axis_name='i')
>>> out = pmap(f, axis_name='i')(np.arange(4.))
>>> print(out)
[ 0.          0.16666667  0.33333334  0.5       ]
>>> print(out.sum())
1.0


In this example, axis_name is a string, but it can be any Python object with __hash__ and __eq__ defined.
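
The example above needs four devices. A sketch that adapts to however many local devices exist shows the same collective semantics: psum hands the sum over the named axis back to every participant. It assumes a recent JAX release.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(1.0, n + 1.0)  # one scalar per local device: 1, 2, ..., n

# psum over axis 'i' gives every device the same total.
totals = jax.pmap(lambda v: jax.lax.psum(v, axis_name='i'), axis_name='i')(x)

# Dividing by that total normalizes the mapped values, as in the example above.
normed = jax.pmap(lambda v: v / jax.lax.psum(v, axis_name='i'), axis_name='i')(x)
```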

The argument axis_name to pmap names the mapped axis so that collective operations, like jax.lax.psum, can refer to it. Axis names are important particularly in the case of nested pmap functions, where collectives can operate over distinct axes:

>>> from functools import partial
>>> @partial(pmap, axis_name='rows')
... @partial(pmap, axis_name='cols')
... def normalize(x):
...   row_normed = x / jax.lax.psum(x, 'rows')
...   col_normed = x / jax.lax.psum(x, 'cols')
...   doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
...   return row_normed, col_normed, doubly_normed
...
>>> x = np.arange(8.).reshape((4, 2))
>>> row_normed, col_normed, doubly_normed = normalize(x)
>>> print(row_normed.sum(0))
[ 1.  1.]
>>> print(col_normed.sum(1))
[ 1.  1.  1.  1.]
>>> print(doubly_normed.sum((0, 1)))
1.0


On multi-host platforms, collective operations operate over all devices, including those on other hosts. For example, assuming the following code runs on two hosts with 4 XLA devices each:

>>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
>>> data = np.arange(4) if jax.host_id() == 0 else np.arange(4,8)
>>> out = pmap(f, axis_name='i')(data)
>>> print(out)
[28 29 30 31] # on host 0
[32 33 34 35] # on host 1


Each host passes in a different length-4 array, corresponding to its 4 local devices, and the psum operates over all 8 values. Conceptually, the two length-4 arrays can be thought of as a sharded length-8 array (in this example equivalent to np.arange(8)) that is mapped over, with the length-8 mapped axis given the name ‘i’. The pmap call on each host then returns the corresponding length-4 output shard.
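
The arithmetic can be reproduced on a single machine, without pmap, by concatenating the per-host shards into the logical global array. This is only a simulation of what the two hosts compute under the two-host, four-device assumption of the example above; the names shards and outputs are illustrative.

```python
import jax.numpy as jnp

# Simulated shards: what host 0 and host 1 would each pass to pmap.
shards = [jnp.arange(4), jnp.arange(4, 8)]

# The logical global array is the concatenation of the shards (== arange(8)).
global_array = jnp.concatenate(shards)
total = global_array.sum()  # what psum over 'i' computes across both hosts

# Each host gets back only the output shard matching its own input shard.
outputs = [shard + total for shard in shards]
```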

The devices argument can be used to specify exactly which devices are used to run the parallel computation. For example, again assuming a single host with 8 devices, the following code defines two parallel computations, one which runs on the first six devices and one on the remaining two:

>>> from functools import partial
>>> @partial(pmap, axis_name='i', devices=jax.devices()[:6])
... def f1(x):
...   return x / jax.lax.psum(x, axis_name='i')
...
>>> @partial(pmap, axis_name='i', devices=jax.devices()[-2:])
... def f2(x):
...   return jax.lax.psum(x ** 2, axis_name='i')
...
>>> print(f1(np.arange(6.)))
[0.         0.06666667 0.13333333 0.2        0.26666667 0.33333333]
>>> print(f2(np.array([2., 3.])))
[ 13.  13.]
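
When fewer than eight devices are available, the same pattern can be exercised by taking a slice of jax.local_devices(). The sketch below picks the first half of the local devices (at least one) and assumes a recent JAX release in which pmap still accepts the devices argument.

```python
from functools import partial

import jax
import jax.numpy as jnp

local = jax.local_devices()
subset = local[: max(1, len(local) // 2)]  # first half, but at least one device

@partial(jax.pmap, axis_name='i', devices=subset)
def normalize(x):
    # Normalize over the devices in `subset` only.
    return x / jax.lax.psum(x, axis_name='i')

# The leading axis must match the number of devices in `subset`.
x = jnp.arange(1.0, len(subset) + 1.0)
out = normalize(x)
```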

jax.devices(backend=None)

Returns a list of all devices.

Each device is represented by a subclass of Device (e.g. CpuDevice, GpuDevice). The length of the returned list is equal to device_count(). Local devices can be identified by comparing Device.host_id to host_id().

Parameters: backend – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: ‘cpu’, ‘gpu’, or ‘tpu’.

Returns: List of Device subclasses.

jax.local_devices(host_id=None, backend=None)

Returns a list of devices local to a given host (this host by default).
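
The sketch below recovers the local devices by filtering jax.devices() on the host index. It assumes a recent JAX release, in which host_id() and Device.host_id have been renamed to process_index() and Device.process_index; on older releases the host_id names described here apply instead.

```python
import jax

all_devices = jax.devices()  # every device, across all hosts
me = jax.process_index()     # this host's index (older JAX: jax.host_id())

# Local devices are the ones whose process_index matches this host's index.
mine = [d for d in all_devices if d.process_index == me]

for d in mine:
    print(d.id, d.platform, d.process_index)
```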

jax.host_id(backend=None)

Returns the integer host ID of this host.

On most platforms this is always 0; on multi-host platforms it differs from host to host.

Parameters: backend – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: ‘cpu’, ‘gpu’, or ‘tpu’.

Returns: Integer host ID.

jax.host_ids(backend=None)

Returns a list of all host IDs.

jax.device_count(backend=None)

Returns the total number of devices.

On most platforms, this is the same as local_device_count(). However, on multi-host platforms, this will return the total number of devices across all hosts.

Parameters: backend – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: ‘cpu’, ‘gpu’, or ‘tpu’.

Returns: Number of devices.

jax.local_device_count(backend=None)

Returns the number of devices on this host.

jax.host_count(backend=None)

Returns the number of hosts.
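
The counting functions are related: device_count() covers every host while local_device_count() covers only this one, so the local count never exceeds the total, and on a single host the two agree. The sketch assumes a recent JAX release, where host_count() is available under the newer name process_count().

```python
import jax

total = jax.device_count()        # devices across all hosts
local = jax.local_device_count()  # devices attached to this host
hosts = jax.process_count()       # newer name for host_count()

print(f"{hosts} host(s), {local} local device(s), {total} device(s) in total")
```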