# Public API: jax package

## Just-in-time compilation (jit)

• jit(fun[, static_argnums, device, backend, …]) – Sets up fun for just-in-time compilation with XLA.

• disable_jit() – Context manager that disables jit() behavior under its dynamic context.

• xla_computation(fun[, static_argnums, …]) – Creates a function that produces its XLA computation given example args.

• make_jaxpr(fun[, static_argnums]) – Creates a function that produces its jaxpr given example args.

• eval_shape(fun, *args, **kwargs) – Compute the shape/dtype of fun without any FLOPs.

• device_put(x[, device]) – Transfers x to device.

## Automatic differentiation

• grad(fun[, argnums, has_aux, holomorphic]) – Creates a function which evaluates the gradient of fun.

• value_and_grad(fun[, argnums, has_aux, …]) – Create a function which evaluates both fun and the gradient of fun.

• jacfwd(fun[, argnums, holomorphic]) – Jacobian of fun evaluated column-by-column using forward-mode AD.

• jacrev(fun[, argnums, holomorphic]) – Jacobian of fun evaluated row-by-row using reverse-mode AD.

• hessian(fun[, argnums, holomorphic]) – Hessian of fun as a dense array.

• jvp(fun, primals, tangents) – Computes a (forward-mode) Jacobian-vector product of fun.

• linearize(fun, *primals) – Produces a linear approximation to fun using jvp() and partial eval.

• vjp(fun, *primals, **kwargs) – Compute a (reverse-mode) vector-Jacobian product of fun.

• custom_jvp(fun[, nondiff_argnums]) – Set up a JAX-transformable function for a custom JVP rule definition.

• custom_vjp(fun[, nondiff_argnums]) – Set up a JAX-transformable function for a custom VJP rule definition.

• checkpoint(fun[, concrete]) – Make fun recompute internal linearization points when differentiated.

## Vectorization (vmap)

• vmap(fun[, in_axes, out_axes]) – Vectorizing map.

• jax.numpy.vectorize(pyfunc, *[, excluded, …]) – Define a vectorized function with broadcasting.

## Parallelization (pmap)

• pmap(fun[, axis_name, in_axes, …]) – Parallel map with support for collectives.

• devices([backend]) – Returns a list of all devices for a given backend.

• local_devices([host_id, backend]) – Like jax.devices(), but only returns devices local to a given host.

• host_id([backend]) – Returns the integer host ID of this host.

• host_ids([backend]) – Returns a sorted list of all host IDs.

• device_count([backend]) – Returns the total number of devices.

• local_device_count([backend]) – Returns the number of devices on this host.

• host_count([backend]) – Returns the number of hosts.

jax.jit(fun, static_argnums=(), device=None, backend=None, donate_argnums=())

Sets up fun for just-in-time compilation with XLA.

Parameters
• fun (Callable) – Function to be jitted. Should be a pure function, as side-effects may only be executed once. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_argnums can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.

• static_argnums (Union[int, Iterable[int]]) – An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded in Python (during tracing), and so the corresponding argument values can be any Python object. Calling the jitted function with different values for these constants will trigger recompilation. If the jitted function is called with fewer positional arguments than indicated by static_argnums then an error is raised. Arguments that are not arrays or containers thereof must be marked as static. Defaults to ().

• device – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].

• backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.

• donate_argnums (Union[int, Iterable[int]]) – Specify which arguments are “donated” to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to.

Return type

Callable

Returns

A wrapped version of fun, set up for just-in-time compilation.

In the following example, selu can be compiled into a single fused kernel by XLA:

>>> import jax
>>>
>>> @jax.jit
... def selu(x, alpha=1.67, lmbda=1.05):
...   return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
>>>
>>> key = jax.random.PRNGKey(0)
>>> x = jax.random.normal(key, (10,))
>>> print(selu(x))
[-0.54485  0.27744 -0.29255 -0.91421 -0.62452 -0.24748
 -0.85743 -0.78232  0.76827  0.59566 ]
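As a minimal sketch of static_argnums (the `power` helper below is hypothetical, not part of JAX), marking `n` as static lets ordinary Python control flow depend on it: the loop is unrolled during tracing, and calling with a new value of `n` triggers a fresh compilation.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Positional argument 1 (`n`) is static: it must be hashable, and its value
# becomes part of the compilation cache key.
@partial(jax.jit, static_argnums=(1,))
def power(x, n):
    result = jnp.ones_like(x)
    # Python control flow on a static value is fine: the loop is unrolled
    # at trace time into n multiplications.
    for _ in range(n):
        result = result * x
    return result


x = jnp.arange(4.0)
cubed = power(x, 3)    # traces and compiles for n=3
squared = power(x, 2)  # new static value, so this compiles again for n=2
print(cubed)           # values 0, 1, 8, 27
```

Passing a traced array as `n` instead would fail, because `range()` needs a concrete Python integer during tracing.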

jax.disable_jit()

Context manager that disables jit() behavior under its dynamic context.

For debugging it is useful to have a mechanism that disables jit() everywhere in a dynamic context.

Values that have a data dependence on the arguments to a jitted function are traced and abstracted. For example, an abstract value may be a ShapedArray instance, representing the set of all possible arrays with a given shape and dtype, but not representing one concrete array with specific values. You might notice those if you use a benign side-effecting operation in a jitted function, like a print:

>>> import jax
>>>
>>> @jax.jit
... def f(x):
...   y = x * 2
...   print("Value of y is", y)
...   return y + 3
...
>>> print(f(jax.numpy.array([1, 2, 3])))
Value of y is Traced<ShapedArray(int32[3]):JaxprTrace(level=-1/1)>
[5 7 9]


Here y has been abstracted by jit() to a ShapedArray, which represents an array with a fixed shape and type but an arbitrary value. The value of y is also traced. If we want to see a concrete value while debugging, and avoid the tracer too, we can use the disable_jit() context manager:

>>> import jax
>>>
>>> with jax.disable_jit():
...   print(f(jax.numpy.array([1, 2, 3])))
...
Value of y is [2 4 6]
[5 7 9]

jax.xla_computation(fun, static_argnums=(), axis_env=None, backend=None, tuple_args=False, instantiate_const_outputs=None, return_shape=False)

Creates a function that produces its XLA computation given example args.

Parameters
Return type

Callable

Returns

A wrapped version of fun that, when applied to example arguments, returns a built XLA Computation (see xla_client.py), from which representations of the unoptimized XLA HLO computation can be extracted using methods like as_hlo_text, as_serialized_hlo_module_proto, and as_hlo_dot_graph. If the argument return_shape is True, then the wrapped function returns a pair where the first element is the XLA Computation and the second element is a pytree representing the structure, shapes, and dtypes of the output of fun.

For example:

>>> import jax
>>>
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> c = jax.xla_computation(f)(3.)
>>> print(c.as_hlo_text())
HloModule xla_computation_f.6

ENTRY xla_computation_f.6 {
constant.2 = pred[] constant(false)
parameter.1 = f32[] parameter(0)
cosine.3 = f32[] cosine(parameter.1)
sine.4 = f32[] sine(cosine.3)
ROOT tuple.5 = (f32[]) tuple(sine.4)
}



Here’s an example that involves a parallel collective and axis name:

>>> def f(x): return x - jax.lax.psum(x, 'i')
>>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2)
>>> print(c.as_hlo_text())
HloModule jaxpr_computation.9
primitive_computation.3 {
parameter.4 = s32[] parameter(0)
parameter.5 = s32[] parameter(1)
}
ENTRY jaxpr_computation.9 {
tuple.1 = () tuple()
parameter.2 = s32[] parameter(0)
all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3
ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7)
}



Notice the replica_groups that were generated. Here’s an example that generates more interesting replica_groups:

>>> from jax import lax
>>> def g(x):
...   rowsum = lax.psum(x, 'i')
...   colsum = lax.psum(x, 'j')
...   allsum = lax.psum(x, ('i', 'j'))
...   return rowsum, colsum, allsum
...
>>> axis_env = [('i', 4), ('j', 2)]
>>> c = jax.xla_computation(g, axis_env=axis_env)(5.)
>>> print(c.as_hlo_text())
HloModule jaxpr_computation__1.19
[removed uninteresting text here]
ENTRY jaxpr_computation__1.19 {
tuple.1 = () tuple()
parameter.2 = f32[] parameter(0)
all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3
all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8
all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
}

jax.make_jaxpr(fun, static_argnums=())

Creates a function that produces its jaxpr given example args.

Parameters
Return type
Returns

A wrapped version of fun that, when applied to example arguments, returns a TypedJaxpr representation of fun on those arguments.

A jaxpr is JAX’s intermediate representation for program traces. The jaxpr language is based on the simply-typed first-order lambda calculus with let-bindings. make_jaxpr() adapts a function to return its jaxpr, which we can inspect to understand what JAX is doing internally. The jaxpr returned is a trace of fun abstracted to ShapedArray level. Other levels of abstraction exist internally.

We do not describe the semantics of the jaxpr language in detail here, but instead give a few examples.

>>> import jax
>>>
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> print(f(3.0))
-0.83602
>>> jax.make_jaxpr(f)(3.0)
{ lambda  ; a.
  let b = cos a
      c = sin b
  in (c,) }
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda  ; a.
  let b = cos a
      c = cos b
      d = mul 1.0 c
      e = neg d
      f = sin a
      g = mul e f
  in (g,) }
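Beyond printing, a jaxpr can be inspected programmatically. The sketch below assumes the returned object exposes `.jaxpr.eqns`, with each equation carrying a `primitive.name`; these are internal attributes that may change between JAX versions. It also shows that a static argument is baked into the trace rather than becoming a jaxpr input.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sin(jnp.cos(x))


closed = jax.make_jaxpr(f)(3.0)
print(closed)

# Walk the equations of the underlying jaxpr; each records its primitive.
prim_names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(prim_names)


def g(x, n):
    return x ** n


# With static_argnums=(1,), n=2 is fixed at trace time, so only x remains
# an input variable of the jaxpr.
jaxpr_g = jax.make_jaxpr(g, static_argnums=(1,))(3.0, 2)
print(jaxpr_g)
```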

jax.eval_shape(fun, *args, **kwargs)

Compute the shape/dtype of fun without any FLOPs.

This utility function is useful for performing shape inference. Its input/output behavior is defined by:

def eval_shape(fun, *args, **kwargs):
  out = fun(*args, **kwargs)
  return jax.tree_util.tree_map(shape_dtype_struct, out)

def shape_dtype_struct(x):
  return ShapeDtypeStruct(x.shape, x.dtype)

class ShapeDtypeStruct(object):
  __slots__ = ["shape", "dtype"]
  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = dtype


In particular, the output is a pytree of objects that have shape and dtype attributes, but nothing else about them is guaranteed by the API.

But instead of applying fun directly, which might be expensive, it uses JAX’s abstract interpretation machinery to evaluate the shapes without doing any FLOPs.

Using eval_shape() can also catch shape errors, and will raise the same shape errors as evaluating fun(*args, **kwargs).

Parameters
• fun (Callable) – The function whose output shape should be evaluated.

• *args – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types. Since only the shape and dtype attributes are accessed, only values that duck-type arrays are required, rather than real ndarrays. The duck-typed objects cannot be namedtuples because those are treated as standard Python containers. See the example below.

• **kwargs – a keyword argument dict of arrays, scalars, or (nested) standard Python containers (pytrees) of those types. As in args, array values need only be duck-typed to have shape and dtype attributes.

For example:

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> f = lambda A, x: jnp.tanh(jnp.dot(A, x))
>>> class MyArgArray(object):
...   def __init__(self, shape, dtype):
...     self.shape = shape
...     self.dtype = dtype
...
>>> A = MyArgArray((2000, 3000), jnp.float32)
>>> x = MyArgArray((3000, 1000), jnp.float32)
>>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
>>> print(out.shape)
(2000, 1000)
>>> print(out.dtype)
float32
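Instead of a hand-written duck-typed class, jax.ShapeDtypeStruct can describe abstract inputs, and pytrees of them work as described above. In this sketch `layer` is a hypothetical function; the exact exception type for the mismatched shape may vary by JAX version, so both TypeError and ValueError are caught.

```python
import jax
import jax.numpy as jnp


def layer(params, x):
    return jnp.tanh(x @ params["w"] + params["b"])


# ShapeDtypeStruct carries only .shape and .dtype: no memory is allocated
# for these (potentially large) arrays.
params = {
    "w": jax.ShapeDtypeStruct((3000, 1000), jnp.float32),
    "b": jax.ShapeDtypeStruct((1000,), jnp.float32),
}
x = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)

out = jax.eval_shape(layer, params, x)  # no FLOPs performed
print(out.shape, out.dtype)

# A mismatched contraction dimension is reported just as layer(...) would.
bad_x = jax.ShapeDtypeStruct((2000, 42), jnp.float32)
try:
    jax.eval_shape(layer, params, bad_x)
    shape_error = None
except (TypeError, ValueError) as err:
    shape_error = err
print(type(shape_error).__name__)
```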

jax.device_put(x, device=None)

Transfers x to device.

Parameters
• x – An array, scalar, or (nested) standard Python container thereof.

• device (Optional[Device]) – The (optional) Device to which x should be transferred. If given, then the result is committed to the device.

If the device parameter is None, then this operation behaves like the identity function if the operand is on any device already, otherwise it transfers the data to the default device, uncommitted.

For more details on data placement see the FAQ on data placement.

Returns

A copy of x that resides on device.
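A short sketch of the three cases described above, assuming only that at least one device is visible to JAX: no device (uncommitted transfer to the default device), an explicit device (committed), and a pytree, whose leaves are each transferred.

```python
import jax
import numpy as np

x_host = np.arange(6.0, dtype=np.float32).reshape(2, 3)

# No device given: the data goes to the default device, uncommitted.
x_default = jax.device_put(x_host)

# Explicit device: the result is committed to that device.
first_device = jax.devices()[0]
x_committed = jax.device_put(x_host, first_device)

# device_put maps over pytrees, transferring every leaf.
tree = jax.device_put({"w": x_host, "scale": np.float32(2.0)})

print(first_device)
```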

jax.grad(fun, argnums=0, has_aux=False, holomorphic=False)

Creates a function which evaluates the gradient of fun.

Parameters
• fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. Argument arrays in the positions specified by argnums must be of inexact (i.e., floating-point or complex) type. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)

• argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

• has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

• holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.

Return type

Callable

Returns

A function with the same arguments as fun, that evaluates the gradient of fun. If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a pair of (gradient, auxiliary_data) is returned.

For example:

>>> import jax
>>>
>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.961043
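The argnums and has_aux options can be combined. In this sketch (the least-squares `loss` is hypothetical), the gradient is taken with respect to the first two positional arguments, and the residuals are returned as auxiliary data alongside the gradients.

```python
import jax
import jax.numpy as jnp


def loss(w, b, x, y):
    err = jnp.dot(x, w) + b - y
    # Return (scalar to differentiate, auxiliary data).
    return jnp.mean(err ** 2), err


x = jnp.array([[1., 2.], [3., 4.]])
y = jnp.array([1., 2.])
w = jnp.array([0.5, -0.5])
b = 0.0

# argnums=(0, 1) differentiates w.r.t. w and b; has_aux=True makes the
# result a (gradients, aux) pair.
(dw, db), err = jax.grad(loss, argnums=(0, 1), has_aux=True)(w, b, x, y)
print(dw, db, err)  # dw = [-9, -13], db = -4, err = [-1.5, -2.5]
```

Here err = x @ w + b - y = [-1.5, -2.5], so dw = xᵀ err = [-9, -13] and db = sum(err) = -4; dw has the same shape as w, as stated above.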

jax.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False)

Create a function which evaluates both fun and the gradient of fun.

Parameters
• fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)

• argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

• has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

• holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. If True, inputs and outputs must be complex. Default False.

Return type
Returns

A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a sequence of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments.
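A minimal sketch of both call patterns. With has_aux=True, the first element of the returned pair is itself a (value, aux) pair, and the gradient comes second. The functions `f` and `f_aux` are illustrative only.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(x ** 2)


x = jnp.array([1., 2., 3.])
value, grad_x = jax.value_and_grad(f)(x)  # 14.0 and [2., 4., 6.]


def f_aux(x):
    value = jnp.sum(x ** 2)
    return value, {"norm": jnp.sqrt(value)}


# With has_aux=True the result is ((value, aux), gradient).
(value2, aux), grad_x2 = jax.value_and_grad(f_aux, has_aux=True)(x)
print(value, grad_x, aux["norm"])
```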

jax.jacfwd(fun, argnums=0, holomorphic=False)

Jacobian of fun evaluated column-by-column using forward-mode AD.

Parameters
Return type

Callable

Returns

A function with the same arguments as fun, that evaluates the Jacobian of fun using forward-mode automatic differentiation.

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...   return jnp.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacfwd(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]
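When argnums is a tuple, jacfwd() returns one Jacobian block per differentiated argument, each with shape output.shape + argument.shape. In this sketch the hypothetical `g` is elementwise, so both blocks are diagonal.

```python
import jax
import jax.numpy as jnp


def g(x, y):
    # Elementwise, so each Jacobian block is diagonal.
    return x * jnp.sin(y)


x = jnp.array([1., 2., 3.])
y = jnp.array([0.1, 0.2, 0.3])

# One (3, 3) block for x and one for y.
jac_x, jac_y = jax.jacfwd(g, argnums=(0, 1))(x, y)
print(jac_x)  # diag(sin(y))
print(jac_y)  # diag(x * cos(y))
```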

jax.jacrev(fun, argnums=0, holomorphic=False)

Jacobian of fun evaluated row-by-row using reverse-mode AD.

Parameters
Return type

Callable

Returns

A function with the same arguments as fun, that evaluates the Jacobian of fun using reverse-mode automatic differentiation.

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...   return jnp.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jnp.sin(x[0])])
...
>>> print(jax.jacrev(f)(jnp.array([1., 2., 3.])))
[[ 1.       0.       0.     ]
 [ 0.       0.       5.     ]
 [ 0.      16.      -2.     ]
 [ 1.6209   0.       0.84147]]
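Both functions compute the same Jacobian; they differ in cost. Conceptually, jacrev() needs one vector-Jacobian product per output row and jacfwd() one Jacobian-vector product per input column. So jacrev() is typically the better fit for "wide" Jacobians (many inputs, few outputs) and jacfwd() for "tall" ones. A sketch with a hypothetical R^1000 → R^2 function:

```python
import jax
import jax.numpy as jnp


def h(w):
    # R^1000 -> R^2: far more inputs than outputs (a "wide" Jacobian).
    return jnp.stack([jnp.sum(w ** 2), jnp.sum(jnp.sin(w))])


w = jnp.linspace(0., 1., 1000)

jac_rev = jax.jacrev(h)(w)  # rows built from 2 vector-Jacobian products
jac_fwd = jax.jacfwd(h)(w)  # columns built from 1000 Jacobian-vector products
print(jac_rev.shape)        # (2, 1000)
```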

jax.hessian(fun, argnums=0, holomorphic=False)

Hessian of fun as a dense array.

Parameters
• fun (Callable) – Function whose Hessian is to be computed. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers thereof. It should return arrays, scalars, or standard Python containers thereof.

• argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).

• holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Return type

Callable

Returns

A function with the same arguments as fun, that evaluates the Hessian of fun.

>>> import jax
>>>
>>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
>>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
[[   6.   -2.]
 [  -2. -480.]]


hessian() is a generalization of the usual definition of the Hessian that supports nested Python containers (i.e. pytrees) as inputs and outputs. The tree structure of jax.hessian(fun)(x) is given by forming a tree product of the structure of fun(x) with a tree product of two copies of the structure of x. A tree product of two tree structures is formed by replacing each leaf of the first tree with a copy of the second. For example:

>>> import jax.numpy as jnp
>>> f = lambda dct: {"c": jnp.power(dct["a"], dct["b"])}
>>> print(jax.hessian(f)({"a": jnp.arange(2.) + 1., "b": jnp.arange(2.) + 2.}))
{'c': {'a': {'a': DeviceArray([[[ 2.,  0.], [ 0.,  0.]],
                               [[ 0.,  0.], [ 0., 12.]]], dtype=float32),
             'b': DeviceArray([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                               [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32)},
       'b': {'a': DeviceArray([[[ 1.      ,  0.      ], [ 0.      ,  0.      ]],
                               [[ 0.      ,  0.      ], [ 0.      , 12.317766]]], dtype=float32),
             'b': DeviceArray([[[0.      , 0.      ], [0.      , 0.      ]],
                              [[0.      , 0.      ], [0.      , 3.843624]]], dtype=float32)}}}


Thus each leaf in the tree structure of jax.hessian(fun)(x) corresponds to a leaf of fun(x) and a pair of leaves of x. For each leaf in jax.hessian(fun)(x), if the corresponding array leaf of fun(x) has shape (out_1, out_2, ...) and the corresponding array leaves of x have shape (in_1_1, in_1_2, ...) and (in_2_1, in_2_2, ...) respectively, then the Hessian leaf has shape (out_1, out_2, ..., in_1_1, in_1_2, ..., in_2_1, in_2_2, ...). In other words, the Python tree structure represents the block structure of the Hessian, with blocks determined by the input and output pytrees.

In particular, an array is produced (with no pytrees involved) when the function input x and output fun(x) are each a single array, as in the g example above. If fun(x) has shape (out1, out2, ...) and x has shape (in1, in2, ...) then jax.hessian(fun)(x) has shape (out1, out2, ..., in1, in2, ..., in1, in2, ...). To flatten pytrees into 1D vectors, consider using jax.flatten_util.ravel_pytree().
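A sketch of the flattening approach using jax.flatten_util.ravel_pytree, which returns a flat vector plus a function that rebuilds the original pytree. Applied to the same power function as the pytree example above (summed to a scalar), it yields a single (4, 4) array instead of nested blocks. The sketch also checks that hessian() agrees with composing jacfwd() over jacrev().

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"a": jnp.array([1., 2.]), "b": jnp.array([2., 3.])}

# Leaves are flattened in pytree order (dict keys sorted): a, then b.
flat, unravel = ravel_pytree(params)


def f_flat(v):
    p = unravel(v)
    return jnp.sum(jnp.power(p["a"], p["b"]))


H = jax.hessian(f_flat)(flat)  # one dense (4, 4) array
print(H.shape)

# Forward-over-reverse composition gives the same matrix.
H_composed = jax.jacfwd(jax.jacrev(f_flat))(flat)
```

The nonzero entries match the blocks printed in the pytree example, for instance H[1, 3] ≈ 12.317766 for the mixed a/b derivative of the second element.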

jax.jvp(fun, primals, tangents)

Computes a (forward-mode) Jacobian-vector product of fun.

Parameters
• fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.

• primals – The primal values at which the Jacobian of fun should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters of fun.

• tangents – The tangent vector for which the Jacobian-vector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as primals.

Return type
Returns

A (primals_out, tangents_out) pair, where primals_out is fun(*primals), and tangents_out is the Jacobian-vector product of fun evaluated at primals with tangents. The tangents_out value has the same Python tree structure and shapes as primals_out.

For example:

>>> import jax
>>>
>>> y, v = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
>>> print(y)
0.09983342
>>> print(v)
0.19900084
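With several positional arguments, the tangents tuple mirrors the primals tuple. This sketch (the function `f` is illustrative) perturbs only x and compares the result against a central finite difference along the same direction; the finite difference is only a rough float32 check, hence the loose tolerance in the assertions.

```python
import jax
import jax.numpy as jnp


def f(x, y):
    return jnp.sin(x) * y


x = jnp.array([0.1, 0.2])
y = jnp.array([1.0, 2.0])

# Tangent direction: move x, keep y fixed.
out, out_dot = jax.jvp(f, (x, y), (jnp.ones_like(x), jnp.zeros_like(y)))

# Central finite difference along the same direction.
eps = 1e-3
fd = (f(x + eps, y) - f(x - eps, y)) / (2 * eps)
print(out_dot, fd)  # both approximately cos(x) * y
```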

jax.linearize(fun, *primals)

Produces a linear approximation to fun using jvp() and partial eval.

Parameters
• fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.

• primals – The primal values at which the Jacobian of fun should be evaluated. Should be a tuple of arrays, scalars, or standard Python containers thereof. The length of the tuple is equal to the number of positional parameters of fun.

Return type
Returns

A pair where the first element is the value of fun(*primals) and the second element is a function that evaluates the (forward-mode) Jacobian-vector product of fun evaluated at primals without re-doing the linearization work.

In terms of values computed, linearize() behaves much like a curried jvp(), where these two code blocks compute the same values:

y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

y, f_jvp = jax.linearize(f, x)
out_tangent = f_jvp(in_tangent)


However, the difference is that linearize() uses partial evaluation so that the function f is not re-linearized on calls to f_jvp. In general that means the memory usage scales with the size of the computation, much like in reverse-mode. (Indeed, linearize() has a similar signature to vjp()!)

This function is mainly useful if you want to apply f_jvp multiple times, i.e. to evaluate a pushforward for many different input tangent vectors at the same linearization point. Moreover if all the input tangent vectors are known at once, it can be more efficient to vectorize using vmap(), as in:

pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))


By using vmap() and jvp() together like this we avoid the stored-linearization memory cost that scales with the depth of the computation, which is incurred by both linearize() and vjp().

Here’s a more complete example of using linearize():

>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x): return 3. * jnp.sin(x) + jnp.cos(x / 2.)
...
>>> jax.jvp(f, (2.,), (3.,))
(DeviceArray(3.26819, dtype=float32), DeviceArray(-5.00753, dtype=float32))
>>> y, f_jvp = jax.linearize(f, 2.)
>>> print(y)
3.2681944
>>> print(f_jvp(3.))
-5.007528
>>> print(f_jvp(4.))
-6.676704
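The vmap() alternative mentioned above can be made runnable for this same f. The sketch pushes three tangents forward both ways: by reusing f_jvp from linearize(), and by vmapping jvp() over the tangents with the primal output left unmapped (out_axes=(None, 0)). Both should agree, and the first two values match the f_jvp outputs shown above.

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(x):
    return 3. * jnp.sin(x) + jnp.cos(x / 2.)


x = jnp.array(2.)
in_tangents = jnp.array([3., 4., 5.])

# Linearize once, then reuse f_jvp for each tangent.
y, f_jvp = jax.linearize(f, x)
lin_out = jnp.stack([f_jvp(t) for t in in_tangents])

# Push all tangents forward at once; y does not depend on the tangent,
# so its output axis is None (unmapped).
pushfwd = partial(jax.jvp, f, (x,))
y2, vmap_out = jax.vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
print(lin_out, vmap_out)
```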

jax.vjp(fun, *primals, **kwargs)

Compute a (reverse-mode) vector-Jacobian product of fun.

grad() is implemented as a special case of vjp().

Parameters
• fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.

• primals – A sequence of primal values at which the Jacobian of fun should be evaluated. The length of primals should be equal to the number of positional parameters to fun. Each primal value should be an array, a scalar, or a standard Python container thereof.

• has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

Return type
Returns

If has_aux is False, returns a (primals_out, vjpfun) pair, where primals_out is fun(*primals). vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same shape as primals, representing the vector-Jacobian product of fun evaluated at primals. If has_aux is True, returns a (primals_out, vjpfun, aux) tuple where aux is the auxiliary data returned by fun.

>>> import jax
>>>
>>> def f(x, y):
...   return jax.numpy.sin(x), jax.numpy.cos(y)
...
>>> primals, f_vjp = jax.vjp(f, 0.5, 1.0)
>>> xbar, ybar = f_vjp((-0.7, 0.3))
>>> print(xbar)
-0.61430776
>>> print(ybar)
-0.2524413
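Since grad() is a special case of vjp(), pulling back a cotangent of 1.0 through a scalar-output function reproduces the gradient. The sketch checks this and also shows the has_aux form, which returns the auxiliary data as a third element. The functions `f` and `g` are illustrative.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(jnp.tanh(x))


x = jnp.array([0.1, 0.2, 0.3])

# Scalar output: the cotangent 1.0 pulls back to the gradient.
y, f_vjp = jax.vjp(f, x)
(x_bar,) = f_vjp(jnp.ones_like(y))  # one cotangent per primal input


def g(x):
    return jnp.sin(x), {"x_squared": x ** 2}


# has_aux=True: returns (primals_out, vjpfun, aux).
out, g_vjp, aux = jax.vjp(g, x, has_aux=True)
(g_bar,) = g_vjp(jnp.ones_like(out))
print(x_bar, g_bar, aux["x_squared"])
```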

jax.custom_jvp(fun, nondiff_argnums=())

Set up a JAX-transformable function for a custom JVP rule definition.

This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a differentiation transformation (like jax.jvp() or jax.grad()) is applied, in which case a custom user-supplied JVP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation. There is a single instance method, defjvp, which defines the custom JVP rule.

For example:

import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x, y):
  return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out


For a more detailed introduction, see the tutorial.
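A common motivation for a custom JVP rule is numerical stability. In this sketch, a hypothetical `log1pexp` computes log(1 + exp(x)). Naive autodiff produces inf / inf = nan at large x in float32. The custom rule rewrites the derivative as 1 - 1 / (1 + exp(x)), which stays finite.

```python
import jax
import jax.numpy as jnp


def naive_log1pexp(x):
    return jnp.log(1. + jnp.exp(x))


@jax.custom_jvp
def log1pexp(x):
    return jnp.log(1. + jnp.exp(x))


@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    ans = log1pexp(x)
    # Algebraically equal to exp(x) / (1 + exp(x)), but finite when exp(x) overflows.
    ans_dot = (1. - 1. / (1. + jnp.exp(x))) * x_dot
    return ans, ans_dot


print(jax.grad(naive_log1pexp)(100.))  # nan
print(jax.grad(log1pexp)(100.))        # 1.0
```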

jax.custom_vjp(fun, nondiff_argnums=())

Set up a JAX-transformable function for a custom VJP rule definition.

This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a reverse-mode differentiation transformation (like jax.grad()) is applied, in which case a custom user-supplied VJP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation. There is a single instance method, defvjp, which defines the custom VJP rule.

This decorator precludes the use of forward-mode automatic differentiation.

For example:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)


For a more detailed introduction, see the tutorial.
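A sketch of a custom VJP that changes only the backward pass: a hypothetical `clip_gradient` is the identity going forward but clips the incoming cotangent. It also uses nondiff_argnums. As far as this sketch assumes, non-differentiable arguments are passed to the backward rule ahead of the residuals, and the rule returns cotangents only for the differentiable arguments.

```python
from functools import partial

import jax
import jax.numpy as jnp


# lo and hi (positions 0 and 1) are marked non-differentiable.
@partial(jax.custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
    return x  # identity on the forward pass


def clip_gradient_fwd(lo, hi, x):
    return x, None  # no residuals needed


def clip_gradient_bwd(lo, hi, res, g):
    # One cotangent per differentiable argument (only x).
    return (jnp.clip(g, lo, hi),)


clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)


def f(x):
    return 10. * clip_gradient(-0.5, 0.5, x)


print(jax.grad(f)(1.))  # 0.5: the incoming cotangent 10.0 is clipped
```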

jax.vmap(fun, in_axes=0, out_axes=0)

Vectorizing map. Creates a function which maps fun over argument axes.

Parameters
• fun (Callable) – Function to be mapped over additional axes.

• in_axes – A nonnegative integer, None, or (nested) standard Python container (tuple/list/dict) thereof specifying which input array axes to map over. If each positional argument to fun is an array, then in_axes can be a nonnegative integer, a None, or a tuple of integers and Nones with length equal to the number of positional arguments to fun. An integer or None indicates which array axis to map over for all arguments (with None indicating not to map any axis), and a tuple indicates which axis to map for each corresponding positional argument. If the positional arguments to fun are container types, the corresponding element of in_axes can itself be a matching container, so that distinct array axes can be mapped for different container elements. in_axes must be a container tree prefix of the positional argument tuple passed to fun.

At least one positional argument must have in_axes not None. The sizes of the mapped input axes for all mapped positional arguments must all be equal.

• out_axes – A nonnegative integer, None, or (nested) standard Python container (tuple/list/dict) thereof indicating where the mapped axis should appear in the output. All outputs with a mapped axis must have a non-None out_axes specification.

Return type

Callable

Returns

Batched/vectorized version of fun with arguments that correspond to those of fun, but with extra array axes at positions indicated by in_axes, and a return value that corresponds to that of fun, but with extra array axes at positions indicated by out_axes.

For example, we can implement a matrix-matrix product using a vector dot product:

>>> import jax.numpy as jnp
>>> from jax import vmap
>>>
>>> vv = lambda x, y: jnp.vdot(x, y)  #  ([a], [a]) -> []
>>> mv = vmap(vv, (0, None), 0)      #  ([b,a], [a]) -> [b]      (b is the mapped axis)
>>> mm = vmap(mv, (None, 1), 1)      #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)


Here we use [a,b] to indicate an array with shape (a,b). Here are some variants:

>>> mv1 = vmap(vv, (0, 0), 0)   #  ([b,a], [b,a]) -> [b]        (b is the mapped axis)
>>> mv2 = vmap(vv, (0, 1), 0)   #  ([b,a], [a,b]) -> [b]        (b is the mapped axis)
>>> mm2 = vmap(mv2, (1, 1), 0)  #  ([b,c,a], [a,c,b]) -> [c,b]  (c is the mapped axis)
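The definitions above only build functions. This sketch evaluates them on concrete inputs to confirm the shape annotations in the comments: mm agrees with an ordinary matrix product, and mm2 produces a [c,b] result.

```python
import jax
import jax.numpy as jnp

vv = lambda x, y: jnp.vdot(x, y)  # ([a], [a]) -> []
mv = jax.vmap(vv, (0, None), 0)   # ([b,a], [a]) -> [b]
mm = jax.vmap(mv, (None, 1), 1)   # ([b,a], [a,c]) -> [b,c]
mv2 = jax.vmap(vv, (0, 1), 0)     # ([b,a], [a,b]) -> [b]
mm2 = jax.vmap(mv2, (1, 1), 0)    # ([b,c,a], [a,c,b]) -> [c,b]

b, a, c = 2, 3, 4
x = jnp.arange(b * a, dtype=jnp.float32).reshape(b, a)
y = jnp.arange(a * c, dtype=jnp.float32).reshape(a, c)
print(jnp.allclose(mm(x, y), x @ y))  # True

x3 = jnp.ones((b, c, a))
y3 = jnp.ones((a, c, b))
print(mm2(x3, y3).shape)  # (4, 2), i.e. [c, b]
```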


Here’s an example of using container types in in_axes to specify which axes of the container elements to map over:

>>> A, B, C, D = 2, 3, 4, 5
>>> x = jnp.ones((A, B))
>>> y = jnp.ones((B, C))
>>> z = jnp.ones((C, D))
>>> def foo(tree_arg):
...   x, (y, z) = tree_arg
...   return jnp.dot(x, jnp.dot(y, z))
>>> tree = (x, (y, z))
>>> print(foo(tree))
[[12. 12. 12. 12. 12.]
 [12. 12. 12. 12. 12.]]
>>> from jax import vmap
>>> K = 6  # batch size
>>> x = jnp.ones((K, A, B))  # batch axis in different locations
>>> y = jnp.ones((B, K, C))
>>> z = jnp.ones((C, D, K))
>>> tree = (x, (y, z))
>>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
>>> print(vfoo(tree).shape)
(6, 2, 5)


Here’s another example using container types in in_axes, this time a dictionary, to specify the elements of the container to map over:

>>> dct = {'a': 0., 'b': jnp.arange(5.)}
>>> x = 1.
>>> def foo(dct, x):
...  return dct['a'] + dct['b'] + x
>>> out = vmap(foo, in_axes=({'a': None, 'b': 0}, None))(dct, x)
>>> print(out)
[1. 2. 3. 4. 5.]


The results of a vectorized function can be mapped or unmapped. For example, the function below returns a pair with the first element mapped and the second unmapped. out_axes can be set to None (to keep a result unmapped) only for unmapped results.

>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=(0, None))(jnp.arange(2.), 4.))
(DeviceArray([4., 5.], dtype=float32), 8.0)


If out_axes is specified for an unmapped result, the result is broadcast across the mapped axis:

>>> print(vmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None), out_axes=0)(jnp.arange(2.), 4.))
(DeviceArray([4., 5.], dtype=float32), DeviceArray([8., 8.], dtype=float32))


If out_axes is specified for a mapped result, the result is transposed accordingly, so that the mapped axis appears at the requested position.
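A small sketch of that transposition: the same batch of 3 length-2 vectors is mapped along axis 0. With out_axes=1, the batch axis lands last in the output, which is the transpose of the out_axes=0 result. The `double` function is illustrative.

```python
import jax
import jax.numpy as jnp

xs = jnp.arange(6.).reshape(3, 2)  # a batch of 3 vectors of length 2
double = lambda v: v * 2.

out0 = jax.vmap(double, in_axes=0, out_axes=0)(xs)  # batch axis first: (3, 2)
out1 = jax.vmap(double, in_axes=0, out_axes=1)(xs)  # batch axis last:  (2, 3)
print(out0.shape, out1.shape)
```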

jax.numpy.vectorize(pyfunc, *, excluded=frozenset({}), signature=None)

Define a vectorized function with broadcasting.

vectorize() is a convenience wrapper for defining vectorized functions with broadcasting, in the style of NumPy’s generalized universal functions. It allows for defining functions that are automatically repeated across any leading dimensions, without the implementation of the function needing to be concerned about how to handle higher dimensional inputs.

jax.numpy.vectorize() has the same interface as numpy.vectorize, but it is syntactic sugar for an auto-batching transformation (vmap()) rather than a Python loop. This should be considerably more efficient, but the implementation must be written in terms of functions that act on JAX arrays.

Parameters
• pyfunc – function to vectorize.

• excluded – optional set of integers representing positional arguments for which the function will not be vectorized. These will be passed directly to pyfunc unmodified.

• signature – optional generalized universal function signature, e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication. If provided, pyfunc will be called with (and expected to return) arrays with shapes given by the size of corresponding core dimensions. By default, pyfunc is assumed to take scalar arrays as input and output.

Returns

Vectorized version of the given function.
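The excluded parameter matters when an argument is not meant to be batched, such as a Python tuple of coefficients. A minimal sketch, where the helper poly is hypothetical: without excluded={1}, the tuple would be converted to an array and broadcast against x instead of being passed through unmodified:

```python
import jax.numpy as jnp


def poly(x, coeffs):
    # Hypothetical helper: evaluates coeffs[0] + coeffs[1]*x + coeffs[2]*x**2 + ...
    return sum(c * x ** i for i, c in enumerate(coeffs))


# Argument 1 (coeffs) is excluded, so the tuple reaches poly unmodified,
# while x is vectorized elementwise (no signature, so scalar core dims).
vpoly = jnp.vectorize(poly, excluded={1})
out = vpoly(jnp.arange(3.), (1., 2., 3.))
print(out)
```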

Here are a few examples of how one could write vectorized linear algebra routines using vectorize():

import jax.numpy as jnp
from functools import partial

@partial(jnp.vectorize, signature='(k),(k)->(k)')
def cross_product(a, b):
    assert a.shape == b.shape and a.ndim == b.ndim == 1
    return jnp.array([a[1] * b[2] - a[2] * b[1],
                      a[2] * b[0] - a[0] * b[2],
                      a[0] * b[1] - a[1] * b[0]])

@partial(jnp.vectorize, signature='(n,m),(m)->(n)')
def matrix_vector_product(matrix, vector):
    assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape
    return matrix @ vector


These functions are only written to handle 1D or 2D arrays (the assert statements will never be violated), but with vectorize they support inputs of arbitrary dimension with NumPy-style broadcasting, e.g.,

>>> cross_product(jnp.ones(3), jnp.ones(3)).shape
(3,)
>>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2, 3)
>>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape
(2, 2, 3)
>>> matrix_vector_product(jnp.ones(3), jnp.ones(3))
ValueError: input with shape (3,) does not have enough dimensions for all
core dimensions ('n', 'm') on vectorized function with excluded=frozenset()
and signature='(n,m),(m)->(n)'
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape
(2,)
>>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape
(4, 2)  # not the same as jnp.matmul
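To see vectorize() acting as sugar for vmap(), and why the last result differs from jnp.matmul, here is a sketch that recomputes the (4, 2) case three ways. It redefines matrix_vector_product (without the assert) so the block runs on its own:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jnp.vectorize, signature='(n,m),(m)->(n)')
def matrix_vector_product(matrix, vector):
    return matrix @ vector


m = jnp.arange(6.).reshape((2, 3))    # one (2, 3) matrix
vs = jnp.arange(12.).reshape((4, 3))  # a batch of four length-3 vectors

# vectorize broadcasts the single matrix against each of the 4 vectors.
out = matrix_vector_product(m, vs)      # shape (4, 2)
# The same computation as an explicit vmap over the leading axis of vs.
manual = jax.vmap(lambda v: m @ v)(vs)  # shape (4, 2)
# jnp.matmul needs the vectors as columns and puts the batch axis last.
via_matmul = jnp.matmul(m, vs.T)        # shape (2, 4)
```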

jax.pmap(fun, axis_name=None, *, in_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, axis_size=None, donate_argnums=())

Parallel map with support for collectives.

The purpose of pmap() is to express single-program multiple-data (SPMD) programs. Applying pmap() to a function will compile the function with XLA (similarly to jit()), then execute it in parallel on XLA devices, such as multiple GPUs or multiple TPU cores. Semantically it is comparable to vmap() because both transformations map a function over array axes, but where vmap() vectorizes functions by pushing the mapped axis down into primitive operations, pmap() instead replicates the function and executes each replica on its own XLA device in parallel.

Another key difference with vmap() is that while vmap() can only express pure maps, pmap() enables the use of parallel SPMD collective operations, like all-reduce sum.

The mapped axis size must be less than or equal to the number of local XLA devices available, as returned by jax.local_device_count() (unless devices is specified, see below). For nested pmap() calls, the product of the mapped axis sizes must be less than or equal to the number of XLA devices.
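A minimal sketch of staying within this limit on a single host: size the mapped axis from jax.local_device_count(), so the same code runs whether one device (typical for a CPU-only install) or several are present. The input values are arbitrary:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()             # number of local XLA devices
xs = jnp.arange(n * 3.).reshape((n, 3))  # one row per device

# Each device receives one row of shape (3,) and sums it.
row_sums = jax.pmap(jnp.sum)(xs)
print(row_sums)  # one sum per device
```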

Note

pmap() compiles fun, so while it can be combined with jit(), it’s usually unnecessary.

Multi-host platforms: On multi-host platforms such as TPU pods, pmap() is designed to be used in SPMD Python programs, where every host is running the same Python code such that all hosts run the same pmapped function in the same order. Each host should still call the pmapped function with mapped axis size equal to the number of local devices (unless devices is specified, see below), and an array of the same leading axis size will be returned as usual. However, any collective operations in fun will be computed over all participating devices, including those on other hosts, via device-to-device communication. Conceptually, this can be thought of as running a pmap over a single array sharded across hosts, where each host “sees” only its local shard of the input and output. The SPMD model requires that the same multi-host pmaps be run in the same order on all devices, but they can be interspersed with arbitrary operations running on a single host.

Parameters
• fun (Callable) – Function to be mapped over argument axes. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_broadcasted_argnums can be anything at all, provided they are hashable and have an equality operation defined.

• axis_name (Optional[Any]) – Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied.

• in_axes – A nonnegative integer, None, or nested Python container thereof that specifies which axes in the input to map over (see vmap()). Currently, only 0 and None are supported axes for pmap.

• static_broadcasted_argnums (Union[int, Iterable[int]]) – An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded. Calling the pmapped function with different values for these constants will trigger recompilation. If the pmapped function is called with fewer positional arguments than indicated by static_broadcasted_argnums, an error is raised. Each of the static arguments will be broadcast to all devices. Arguments that are not arrays or containers thereof must be marked as static. Defaults to ().

• devices – This is an experimental feature and the API is likely to change. Optional, a sequence of Devices to map over. (Available devices can be retrieved via jax.devices()). If specified, the size of the mapped axis must be equal to the number of local devices in the sequence. Nested pmap() calls with devices specified in either the inner or outer pmap() are not yet supported.

• backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend. ‘cpu’, ‘gpu’, or ‘tpu’.

• axis_size (Optional[int]) – Optional; the size of the mapped axis.

• donate_argnums (Union[int, Iterable[int]]) – Specify which arguments are “donated” to the computation. It is safe to donate arguments if you no longer need them once the computation has finished. In some cases XLA can make use of donated buffers to reduce the amount of memory needed to perform a computation, for example recycling one of your input buffers to store a result. You should not reuse buffers that you donate to a computation; JAX will raise an error if you try to.

Return type

Callable

Returns

A parallelized version of fun with arguments that correspond to those of fun but with extra array axes at positions indicated by in_axes and with output that has an additional leading array axis (with the same size).

For example, assuming 8 XLA devices are available, pmap() can be used as a map along a leading array axis:

>>> import jax.numpy as jnp
>>>
>>> out = pmap(lambda x: x ** 2)(jnp.arange(8))
>>> print(out)
[ 0  1  4  9 16 25 36 49]


When the leading dimension is smaller than the number of available devices, JAX will simply run on a subset of devices:

>>> x = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2))
>>> y = jnp.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
>>> out = pmap(jnp.dot)(x, y)
>>> print(out)
[[[    4.     9.]
  [   12.    29.]]
 [[  244.   345.]
  [  348.   493.]]
 [[ 1412.  1737.]
  [ 1740.  2141.]]]


If your leading dimension is larger than the number of available devices, you will get an error:

>>> pmap(lambda x: x ** 2)(jnp.arange(9))
ValueError: ... requires 9 replicas, but only 8 XLA devices are available


As with vmap(), using None in in_axes indicates that an argument doesn’t have an extra axis and should be broadcasted, rather than mapped, across the replicas:

>>> x, y = jnp.arange(2.), 4.
>>> out = pmap(lambda x, y: (x + y, y * 2.), in_axes=(0, None))(x, y)
>>> print(out)
([4., 5.], [8., 8.])


Note that pmap() always returns values mapped over their leading axis, equivalent to using out_axes=0 in vmap().
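A sketch of this contrast with vmap(), reusing the function from the example above and assuming a single host. With pmap() the unmapped result y * 2. comes back replicated along a leading axis of size jax.local_device_count(), whereas vmap() with out_axes=(0, None) leaves it unbatched:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x, y = jnp.arange(float(n)), 4.


def f(x, y):
    return x + y, y * 2.


# pmap: every output gains a leading mapped axis of size n, even y * 2.
p_sums, p_doubled = jax.pmap(f, in_axes=(0, None))(x, y)
# vmap with out_axes=(0, None): the unmapped output stays unbatched.
v_sums, v_doubled = jax.vmap(f, in_axes=(0, None), out_axes=(0, None))(x, y)
```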

In addition to expressing pure maps, pmap() can also be used to express parallel single-program multiple-data (SPMD) programs that communicate via collective operations. For example:

>>> f = lambda x: x / jax.lax.psum(x, axis_name='i')
>>> out = pmap(f, axis_name='i')(jnp.arange(4.))
>>> print(out)
[ 0.          0.16666667  0.33333334  0.5       ]
>>> print(out.sum())
1.0


In this example, axis_name is a string, but it can be any Python object with __hash__ and __eq__ defined.

The argument axis_name to pmap() names the mapped axis so that collective operations, like jax.lax.psum(), can refer to it. Axis names are important particularly in the case of nested pmap() functions, where collectives can operate over distinct axes:

>>> from functools import partial
>>> import jax
>>>
>>> @partial(pmap, axis_name='rows')
... @partial(pmap, axis_name='cols')
... def normalize(x):
...   row_normed = x / jax.lax.psum(x, 'rows')
...   col_normed = x / jax.lax.psum(x, 'cols')
...   doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
...   return row_normed, col_normed, doubly_normed
>>>
>>> x = jnp.arange(8.).reshape((4, 2))
>>> row_normed, col_normed, doubly_normed = normalize(x)
>>> print(row_normed.sum(0))
[ 1.  1.]
>>> print(col_normed.sum(1))
[ 1.  1.  1.  1.]
>>> print(doubly_normed.sum((0, 1)))
1.0


On multi-host platforms, collective operations operate over all devices, including those on other hosts. For example, assuming the following code runs on two hosts with 4 XLA devices each:

>>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
>>> data = jnp.arange(4) if jax.host_id() == 0 else jnp.arange(4,8)
>>> out = pmap(f, axis_name='i')(data)
>>> print(out)
[28 29 30 31] # on host 0
[32 33 34 35] # on host 1


Each host passes in a different length-4 array, corresponding to its 4 local devices, and the psum operates over all 8 values. Conceptually, the two length-4 arrays can be thought of as a sharded length-8 array (in this example equivalent to jnp.arange(8)) that is mapped over, with the length-8 mapped axis given name ‘i’. The pmap call on each host then returns the corresponding length-4 output shard.
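The conceptual picture can be checked on one device, with no hosts involved: compute the whole length-8 result at once, then split it into the two shards each host would receive. This only illustrates the semantics; it is not how pmap() executes:

```python
import jax.numpy as jnp

# Single-array view of the two-host example: 8 values in total.
global_data = jnp.arange(8)
# psum over 'i' adds up all 8 values, so every element gets the same total added.
global_out = global_data + global_data.sum()

# Each host only "sees" its own length-4 shard of the output.
host0_shard, host1_shard = global_out[:4], global_out[4:]
print(host0_shard, host1_shard)
```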

The devices argument can be used to specify exactly which devices are used to run the parallel computation. For example, again assuming a single host with 8 devices, the following code defines two parallel computations, one which runs on the first six devices and one on the remaining two:

>>> from functools import partial
>>> @partial(pmap, axis_name='i', devices=jax.devices()[:6])
... def f1(x):
...   return x / jax.lax.psum(x, axis_name='i')
>>>
>>> @partial(pmap, axis_name='i', devices=jax.devices()[-2:])
... def f2(x):
...   return jax.lax.psum(x ** 2, axis_name='i')
>>>
>>> print(f1(jnp.arange(6.)))
[0.         0.06666667 0.13333333 0.2        0.26666667 0.33333333]
>>> print(f2(jnp.array([2., 3.])))
[ 13.  13.]
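The static_broadcasted_argnums parameter can be illustrated in the same decorator style. In this sketch the function name power is hypothetical; k is a Python int marked static, so the Python loop over range(k) runs at trace time, and each new value of k compiles a separate version. The mapped axis is sized from jax.local_device_count() so the block also runs on a single device:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.pmap, static_broadcasted_argnums=(1,))
def power(x, k):
    # k is a plain Python int (static), so it can drive a Python loop.
    result = jnp.ones_like(x)
    for _ in range(k):
        result = result * x
    return result


n = jax.local_device_count()
xs = jnp.arange(n * 2.).reshape((n, 2)) + 1.

cubes = power(xs, 3)    # compiled for k=3
squares = power(xs, 2)  # a new value of k triggers recompilation
```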

jax.devices(backend=None)

Returns a list of all devices for a given backend.

Each device is represented by a subclass of Device (e.g. CpuDevice, GpuDevice). The length of the returned list is equal to device_count(backend). Local devices can be identified by comparing Device.host_id() to the value returned by jax.host_id().

If backend is None, returns all the devices from the default backend. The default backend is generally 'gpu' or 'tpu' if available, otherwise 'cpu'.

Parameters

backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the xla backend: 'cpu', 'gpu', or 'tpu'.

Returns

List of Device subclasses.
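A short sketch of inspecting the returned list. It relies only on the Device attributes platform and id, and on the CPU backend being available, which is the case in a standard JAX install:

```python
import jax

devices = jax.devices()           # devices of the default backend
cpu_devices = jax.devices('cpu')  # explicitly request the CPU backend

for d in devices:
    # Each entry is a Device object with a platform name and an integer id.
    print(d.platform, d.id)
```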

jax.local_devices(host_id=None, backend=None)

Like jax.devices(), but only returns devices local to a given host.

If host_id is None, returns devices local to this host.

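Parameters
• host_id (Optional[int]) – Optional, the integer ID of the host whose local devices are returned.

• backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the xla backend: 'cpu', 'gpu', or 'tpu'.

Returns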

List of Device subclasses.
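A sketch relating the device-listing functions on the current host. On a single-host setup the local and global counts agree; on multi-host platforms the global list is larger:

```python
import jax

local = jax.local_devices()  # devices attached to this host
global_ids = {d.id for d in jax.devices()}

# Every local device also appears in the global device list.
assert all(d.id in global_ids for d in local)
print(len(local), jax.local_device_count(), jax.device_count())
```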

jax.host_id(backend=None)

Returns the integer host ID of this host.

On most platforms this is always 0; it varies only on multi-host platforms.

Parameters

backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the xla backend: 'cpu', 'gpu', or 'tpu'.

Returns

Integer host ID.

jax.host_ids(backend=None)

Returns a sorted list of all host IDs.

Parameters

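backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the xla backend: 'cpu', 'gpu', or 'tpu'.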

jax.device_count(backend=None)

Returns the total number of devices.

On most platforms, this is the same as jax.local_device_count(). However, on multi-host platforms, this will return the total number of devices across all hosts.

Parameters

backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the xla backend: 'cpu', 'gpu', or 'tpu'.

Returns

Number of devices.

jax.local_device_count(backend=None)

Returns the number of devices on this host.

Parameters

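backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the xla backend: 'cpu', 'gpu', or 'tpu'.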

jax.host_count(backend=None)

Returns the number of hosts.

Parameters

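backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the xla backend: 'cpu', 'gpu', or 'tpu'.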