# Public API: jax package

## Just-in-time compilation (jit)

jax.jit(fun, static_argnums=(), device=None, backend=None)

Sets up fun for just-in-time compilation with XLA.

Parameters:

- fun (Callable) – Function to be jitted. Should be a pure function, as side-effects may only be executed once. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof. Positional arguments indicated by static_argnums can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.
- static_argnums (Union[int, Iterable[int]]) – An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded. Calling the jitted function with different values for these constants will trigger recompilation. If the jitted function is called with fewer positional arguments than indicated by static_argnums then an error is raised. Defaults to ().
- device – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].
- backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: ‘cpu’, ‘gpu’, or ‘tpu’.

Return type: Callable

Returns: A wrapped version of fun, set up for just-in-time compilation.

In the following example, selu can be compiled into a single fused kernel by XLA:

>>> import jax
>>> @jax.jit
... def selu(x, alpha=1.67, lmbda=1.05):
...   return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
...
>>> key = jax.random.PRNGKey(0)
>>> x = jax.random.normal(key, (10,))
>>> print(selu(x))
[-0.54485154  0.27744263 -0.29255125 -0.91421586 -0.62452525 -0.2474813
-0.8574326  -0.7823267   0.7682731   0.59566754]
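The static_argnums behaviour described in the parameter list can be sketched as follows. This is a minimal illustration rather than part of the original docstring: power_sum and its arguments are hypothetical names. Because n is marked static, the Python loop is unrolled while tracing, and calling with a different n would trigger a recompilation.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical example function: argument 1 (n) is a compile-time constant.
@partial(jax.jit, static_argnums=(1,))
def power_sum(x, n):
    # n is an ordinary Python int here, so range(n) works during tracing.
    total = jnp.zeros_like(x)
    for k in range(1, n + 1):
        total = total + x ** k
    return total


x = jnp.array([1.0, 2.0])
result = power_sum(x, 3)  # computes x + x**2 + x**3 elementwise
```

Passing n as a non-static argument would fail here, since a traced value cannot be used as the bound of a Python range.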

jax.disable_jit()

Context manager that disables jit behavior under its dynamic context.

For debugging purposes, it is useful to have a mechanism that disables jit everywhere in a dynamic context.

Values that have a data dependence on the arguments to a jitted function are traced and abstracted. For example, an abstract value may be a ShapedArray instance, representing the set of all possible arrays with a given shape and dtype, but not representing one concrete array with specific values. You might notice those if you use a benign side-effecting operation in a jitted function, like a print:

>>> @jax.jit
... def f(x):
...   y = x * 2
...   print("Value of y is", y)
...   return y + 3
...
>>> print(f(jax.numpy.array([1, 2, 3])))
Value of y is Traced<ShapedArray(int32[3]):JaxprTrace(level=-1/1)>
[5 7 9]


Here y has been abstracted by jit to a ShapedArray, which represents an array with a fixed shape and type but an arbitrary value. It’s also traced. If we want to see a concrete value while debugging, and avoid the tracer too, we can use the disable_jit context manager:

>>> with jax.disable_jit():
...   print(f(jax.numpy.array([1, 2, 3])))
...
Value of y is [2 4 6]
[5 7 9]
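The difference between traced and untraced execution can also be observed by counting Python side effects. The following is a small sketch under that assumption; double and calls are hypothetical names chosen for illustration.

```python
import jax
import jax.numpy as jnp

calls = []  # records every time the Python body of `double` actually runs


@jax.jit
def double(x):
    calls.append("ran")  # Python side effect: only executed while tracing
    return x * 2


x = jnp.arange(3)
double(x)
double(x)  # same shape and dtype: the compiled version is reused
jit_calls = len(calls)

with jax.disable_jit():
    double(x)
    double(x)  # the Python body runs on every call
nojit_calls = len(calls) - jit_calls
```

Under jit the body runs once (during tracing); under disable_jit it runs on each call, which is why prints and debuggers see concrete values there.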

jax.xla_computation(fun, static_argnums=(), axis_env=None, backend=None, tuple_args=False, instantiate_const_outputs=True)

Creates a function that produces its XLA computation given example args.

Parameters:

- fun (Callable) – Function from which to form XLA computations.
- static_argnums (Union[int, Iterable[int]]) – See the jax.jit docstring.
- axis_env (Optional[Sequence[Tuple[Any, int]]]) – Optional, a sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap. See the examples below.
- backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: ‘cpu’, ‘gpu’, or ‘tpu’.
- tuple_args (bool) – Optional bool, defaults to False. If True, the resulting XLA computation will have a single tuple argument that is unpacked into the specified function arguments.
- instantiate_const_outputs (bool) – Optional bool, defaults to True. If False, then xla_computation does not instantiate constant-valued outputs in the XLA computation, and so the result is closer to the computation that jax.jit produces and may be more useful for studying jit behavior. If True, then constant-valued outputs are instantiated in the XLA computation, which may be more useful for staging computations out of JAX entirely.

Return type: Callable

Returns: A wrapped version of fun that when applied to example arguments returns a built XLA Computation (see xla_client.py), from which representations of the unoptimized XLA HLO computation can be extracted using methods like GetHloText, GetSerializedProto, and GetHloDotGraph.

For example:

>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> c = jax.xla_computation(f)(3.)
>>> print(c.GetHloText())
HloModule jaxpr_computation__4.5
ENTRY jaxpr_computation__4.5 {
tuple.1 = () tuple()
parameter.2 = f32[] parameter(0)
cosine.3 = f32[] cosine(parameter.2)
ROOT sine.4 = f32[] sine(cosine.3)
}


Here’s an example that involves a parallel collective and axis name:

>>> def f(x): return x - jax.lax.psum(x, 'i')
>>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2)
>>> print(c.GetHloText())
HloModule jaxpr_computation.9
primitive_computation.3 {
parameter.4 = s32[] parameter(0)
parameter.5 = s32[] parameter(1)
ROOT add.6 = s32[] add(parameter.4, parameter.5)
}
ENTRY jaxpr_computation.9 {
tuple.1 = () tuple()
parameter.2 = s32[] parameter(0)
all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3
ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7)
}


Notice the replica_groups that were generated. Here’s an example that generates more interesting replica_groups:

>>> def g(x):
...   rowsum = jax.lax.psum(x, 'i')
...   colsum = jax.lax.psum(x, 'j')
...   allsum = jax.lax.psum(x, ('i', 'j'))
...   return rowsum, colsum, allsum
...
>>> axis_env = [('i', 4), ('j', 2)]
>>> c = jax.xla_computation(g, axis_env=axis_env)(5.)
>>> print(c.GetHloText())
HloModule jaxpr_computation__1.19
[removed uninteresting text here]
ENTRY jaxpr_computation__1.19 {
tuple.1 = () tuple()
parameter.2 = f32[] parameter(0)
all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3
all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8
all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
}

jax.make_jaxpr(fun)

Creates a function that produces its jaxpr given example args.

Parameters:

- fun (Callable) – The function whose jaxpr is to be computed. Its positional arguments and return value should be arrays, scalars, or standard Python containers (tuple/list/dict) thereof.

Return type: Callable[…, TypedJaxpr]

Returns: A wrapped version of fun that when applied to example arguments returns a TypedJaxpr representation of fun on those arguments.

A jaxpr is JAX’s intermediate representation for program traces. The jaxpr language is based on the simply-typed first-order lambda calculus with let-bindings. make_jaxpr adapts a function to return its jaxpr, which we can inspect to understand what JAX is doing internally.

The jaxpr returned is a trace of fun abstracted to ShapedArray level. Other levels of abstraction exist internally.

We do not describe the semantics of the jaxpr language in detail here, but instead give a few examples.

>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> print(f(3.0))
-0.83602184
>>> jax.make_jaxpr(f)(3.0)
{ lambda  ;  ; a.
let b = cos a
c = sin b
in [c] }
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda  ;  ; a.
let b = cos a
c = cos b
d = mul 1.0 c
e = neg d
f = sin a
g = mul e f
in [g] }

jax.eval_shape(fun, *args, **kwargs)

Compute the shape/dtype of fun(*args, **kwargs) without any FLOPs.

This utility function is useful for performing shape inference. Its input/output behavior is defined by:

def eval_shape(fun, *args, **kwargs):
  out = fun(*args, **kwargs)
  return jax.tree_util.tree_map(shape_dtype_struct, out)

def shape_dtype_struct(x):
  return ShapeDtypeStruct(x.shape, x.dtype)

class ShapeDtypeStruct(object):
  __slots__ = ["shape", "dtype"]
  def __init__(self, shape, dtype):
    self.shape = shape
    self.dtype = dtype


In particular, the output is a pytree of objects that have shape and dtype attributes, but nothing else about them is guaranteed by the API.

But instead of applying fun directly, which might be expensive, it uses JAX’s abstract interpretation machinery to evaluate the shapes without doing any FLOPs.

Using eval_shape can also catch shape errors: it raises the same shape errors as evaluating fun(*args, **kwargs).

Parameters:

- *args – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types. Since only the shape and dtype attributes are accessed, only values that duck-type arrays are required, rather than real ndarrays. The duck-typed objects cannot be namedtuples because those are treated as standard Python containers. See the example below.
- **kwargs – a keyword argument dict of arrays, scalars, or (nested) standard Python containers (pytrees) of those types. As in args, array values need only be duck-typed to have shape and dtype attributes.

For example:

>>> import jax.numpy as np
>>> f = lambda A, x: np.tanh(np.dot(A, x))
>>> class MyArgArray(object):
...   def __init__(self, shape, dtype):
...     self.shape = shape
...     self.dtype = dtype
...
>>> A = MyArgArray((2000, 3000), np.float32)
>>> x = MyArgArray((3000, 1000), np.float32)
>>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
>>> print(out.shape)
(2000, 1000)
>>> print(out.dtype)
float32
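Since both inputs and outputs may be pytrees, shape inference can be run over a whole parameter structure. The sketch below assumes that jax.ShapeDtypeStruct is available as a public duck-typed placeholder (it is in recent JAX releases; the version documented here may only define it internally). The model function and the params layout are hypothetical.

```python
import jax
import jax.numpy as jnp


def model(params, x):
    # Hypothetical single dense layer returning a dict (a pytree) of results.
    hidden = jnp.tanh(jnp.dot(x, params["w"]) + params["b"])
    return {"hidden": hidden, "mean": hidden.mean()}


# Placeholders carrying only shape and dtype: no large arrays are allocated.
params = {
    "w": jax.ShapeDtypeStruct((3000, 1000), jnp.float32),
    "b": jax.ShapeDtypeStruct((1000,), jnp.float32),
}
x = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)

out = jax.eval_shape(model, params, x)
hidden_shape = out["hidden"].shape
mean_shape = out["mean"].shape
```

The result mirrors the structure returned by model, with each leaf replaced by an object exposing shape and dtype.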

jax.device_put(x, device=None)

Transfers x to device.

Parameters:

- x – An array, scalar, or (nested) standard Python container thereof.
- device – The Device to transfer x to.

Returns: A copy of x that resides on device. If x is already on device, returns x.
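A minimal usage sketch, assuming at least one device is reported by jax.devices() (which is the case for the default CPU backend). The variable names are illustrative only.

```python
import numpy as onp

import jax

host_array = onp.arange(4, dtype=onp.float32)  # ordinary NumPy array on the host
device = jax.devices()[0]                      # first available device

on_device = jax.device_put(host_array, device)  # explicit transfer
doubled = on_device * 2                         # computed from the on-device copy
```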

## Automatic differentiation

jax.grad(fun, argnums=0, has_aux=False, holomorphic=False)

Creates a function which evaluates the gradient of fun.

Parameters:

- fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
- argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
- has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
- holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Return type: Callable

Returns: A function with the same arguments as fun, that evaluates the gradient of fun. If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a pair of (gradient, auxiliary_data) is returned.

For example:

>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.961043
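The argnums and has_aux options can be combined. The following sketch uses a hypothetical loss function that returns its predictions as auxiliary data; the names and values are chosen purely for illustration.

```python
import jax
import jax.numpy as jnp


def loss(w, b, x):
    pred = w * x + b
    # First element is the scalar to differentiate, second is auxiliary data.
    return jnp.sum(pred ** 2), pred


w, b = 2.0, 1.0
x = jnp.array([1.0, 2.0])

# Differentiate with respect to arguments 0 and 1; aux is passed through.
(dw, db), pred = jax.grad(loss, argnums=(0, 1), has_aux=True)(w, b, x)
```

Here pred is [3, 5], so dw = sum(2 * pred * x) = 26 and db = sum(2 * pred) = 16.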

jax.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False)

Creates a function which evaluates both fun and the gradient of fun.

Parameters:

- fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
- argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
- has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
- holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Return type: Callable

Returns: A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a sequence of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments.
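A short sketch of the returned pair, using a hypothetical squared-norm function:

```python
import jax
import jax.numpy as jnp


def sq_norm(x):
    return jnp.sum(x ** 2)


x = jnp.array([1.0, 2.0, 3.0])
# One call yields both the function value and its gradient.
value, gradient = jax.value_and_grad(sq_norm)(x)
```

This avoids evaluating the function a second time when both the loss value and its gradient are needed, as is typical in a training loop.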

jax.jacfwd(fun, argnums=0, holomorphic=False)

Jacobian of fun evaluated column-by-column using forward-mode AD.

Parameters:

- fun (Callable) – Function whose Jacobian is to be computed.
- argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
- holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Return type: Callable

Returns: A function with the same arguments as fun, that evaluates the Jacobian of fun using forward-mode automatic differentiation.

>>> def f(x):
...   return jax.numpy.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jax.numpy.sin(x[0])])
...
>>> print(jax.jacfwd(f)(np.array([1., 2., 3.])))
[[ 1.        ,  0.        ,  0.        ],
[ 0.        ,  0.        ,  5.        ],
[ 0.        , 16.        , -2.        ],
[ 1.6209068 ,  0.        ,  0.84147096]]

jax.jacrev(fun, argnums=0, holomorphic=False)

Jacobian of fun evaluated row-by-row using reverse-mode AD.

Parameters:

- fun (Callable) – Function whose Jacobian is to be computed.
- argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
- holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Return type: Callable

Returns: A function with the same arguments as fun, that evaluates the Jacobian of fun using reverse-mode automatic differentiation.

>>> def f(x):
...   return jax.numpy.asarray(
...     [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jax.numpy.sin(x[0])])
...
>>> print(jax.jacrev(f)(np.array([1., 2., 3.])))
[[ 1.        ,  0.        ,  0.        ],
[ 0.        ,  0.        ,  5.        ],
[ 0.        , 16.        , -2.        ],
[ 1.6209068 ,  0.        ,  0.84147096]]

jax.hessian(fun, argnums=0, holomorphic=False)

Hessian of fun.

Parameters:

- fun (Callable) – Function whose Hessian is to be computed.
- argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
- holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.

Return type: Callable

Returns: A function with the same arguments as fun, that evaluates the Hessian of fun.

>>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
>>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
[[   6.,   -2.],
[  -2., -480.]]
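A Hessian can also be assembled by composing the two Jacobian transformations documented above. The sketch below compares jax.hessian with a forward-over-reverse composition on the same function; it is an illustration of how the pieces fit together, not a statement about internal implementation details.

```python
import jax
import jax.numpy as jnp


def g(x):
    return x[0] ** 3 - 2 * x[0] * x[1] - x[1] ** 6


point = jnp.array([1.0, 2.0])

hess = jax.hessian(g)(point)
# Forward-mode Jacobian of the reverse-mode gradient.
hess_composed = jax.jacfwd(jax.jacrev(g))(point)
```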

jax.jvp(fun, primals, tangents)

Computes a (forward-mode) Jacobian-vector product of fun.

Parameters:

- fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
- primals – The primal values at which the Jacobian of fun should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters of fun.
- tangents – The tangent vector for which the Jacobian-vector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as primals.

Returns: A (primals_out, tangents_out) pair, where primals_out is fun(*primals), and tangents_out is the Jacobian-vector product of fun evaluated at primals with tangents. The tangents_out value has the same Python tree structure and shapes as primals_out.

For example:

>>> y, v = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
>>> print(y)
0.09983342
>>> print(v)
0.19900084
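For a function of several arguments, primals and tangents each hold one entry per positional parameter. A minimal sketch with a hypothetical two-argument product:

```python
import jax


def product(x, y):
    return x * y


# Evaluate at (x, y) = (2, 3) along the direction (dx, dy) = (1, 1).
primal_out, tangent_out = jax.jvp(product, (2.0, 3.0), (1.0, 1.0))
```

The tangent is y * dx + x * dy = 3 + 2 = 5, which is the directional derivative of the product along (1, 1).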

jax.linearize(fun, *primals)

Produce a linear approximation to fun using jvp and partial evaluation.

Parameters:

- fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
- primals – The primal values at which the Jacobian of fun should be evaluated. Should be a tuple of arrays, scalar, or standard Python container thereof. The length of the tuple is equal to the number of positional parameters of fun.

Returns: A pair where the first element is the value of fun(*primals) and the second element is a function that evaluates the (forward-mode) Jacobian-vector product of fun evaluated at primals without re-doing the linearization work.

In terms of values computed, linearize behaves much like a curried jvp, where these two code blocks compute the same values:

y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

y, f_jvp = jax.linearize(f, x)
out_tangent = f_jvp(in_tangent)


However, the difference is that linearize uses partial evaluation so that the function f is not re-linearized on calls to f_jvp. In general that means the memory usage scales with the size of the computation, much like in reverse-mode. (Indeed, linearize has a similar signature to vjp!)

This function is mainly useful if you want to apply f_jvp multiple times, i.e. to evaluate a pushforward for many different input tangent vectors at the same linearization point. Moreover if all the input tangent vectors are known at once, it can be more efficient to vectorize using vmap, as in:

pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))


By using vmap and jvp together like this we avoid the stored-linearization memory cost that scales with the depth of the computation, which is incurred by both linearize and vjp.
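The vmap-over-jvp pattern above can be written out as a runnable sketch and compared against linearize. The function f, the point x and the tangent values are illustrative assumptions.

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(x):
    return 3.0 * jnp.sin(x) + jnp.cos(x / 2.0)


x = 2.0
in_tangents = jnp.array([3.0, 4.0])  # a batch of input tangents

# Batch the pushforward over tangents; the primal output is not batched.
pushfwd = partial(jax.jvp, f, (x,))
y, out_tangents = jax.vmap(pushfwd, out_axes=(None, 0))((in_tangents,))

# Reference: apply the stored linearization to each tangent in turn.
_, f_jvp = jax.linearize(f, x)
expected = jnp.stack([f_jvp(t) for t in in_tangents])
```

Both routes give the same tangents; they differ in how much intermediate state is kept around, as discussed above.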

Here’s a more complete example of using linearize:

>>> def f(x): return 3. * np.sin(x) + np.cos(x / 2.)
...
>>> jax.jvp(f, (2.,), (3.,))
(array(3.2681944, dtype=float32), array(-5.007528, dtype=float32))
>>> y, f_jvp = jax.linearize(f, 2.)
>>> print(y)
3.2681944
>>> print(f_jvp(3.))
-5.007528
>>> print(f_jvp(4.))
-6.676704

jax.vjp(fun, *primals, **kwargs)

Compute a (reverse-mode) vector-Jacobian product of fun.

grad() is implemented as a special case of vjp().

Parameters:

- fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
- primals – A sequence of primal values at which the Jacobian of fun should be evaluated. The length of primals should be equal to the number of positional parameters to fun. Each primal value should be a tuple of arrays, scalar, or standard Python containers thereof.
- has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.

Returns: If has_aux is False, returns a (primals_out, vjpfun) pair, where primals_out is fun(*primals). vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same shape as primals, representing the vector-Jacobian product of fun evaluated at primals. If has_aux is True, returns a (primals_out, vjpfun, aux) tuple where aux is the auxiliary data returned by fun.

>>> def f(x, y):
...   return jax.numpy.sin(x), jax.numpy.cos(y)
...
>>> primals, f_vjp = jax.vjp(f, 0.5, 1.0)
>>> xbar, ybar = f_vjp((-0.7, 0.3))
>>> print(xbar)
-0.61430776
>>> print(ybar)
-0.2524413
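The relationship between grad and vjp mentioned above can be illustrated for a scalar-valued function: pulling back a cotangent of one through vjp yields the gradient. This is a sketch with a hypothetical function, not a description of how grad is written internally.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(jnp.sin(x))


x = jnp.array([0.0, 0.5, 1.0])

y, f_vjp = jax.vjp(f, x)
# f_vjp returns one cotangent per primal, hence the one-element tuple.
(grad_from_vjp,) = f_vjp(jnp.ones_like(y))

grad_direct = jax.grad(f)(x)
```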

jax.custom_jvp(fun, nondiff_argnums=())

Set up a JAX-transformable function for a custom JVP rule definition.

This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a differentiation transformation (like jax.jvp or jax.grad) is applied, in which case a custom user-supplied JVP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation. There is a single instance method, defjvp, which defines the custom JVP rule.

For example:

import jax
import jax.numpy as np

@jax.custom_jvp
def f(x, y):
  return np.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  # Product rule: d(sin(x) * y) = cos(x) * y * dx + sin(x) * dy
  tangent_out = np.cos(x) * x_dot * y + np.sin(x) * y_dot
  return primal_out, tangent_out
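As a usage sketch, a function with a custom JVP rule can be passed to jax.grad like any other function, and the rule is used in place of tracing the body. The block below restates the definition with the product rule written out so that it is self-contained; the evaluation point is an arbitrary choice.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def f(x, y):
    return jnp.sin(x) * y


@f.defjvp
def f_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y)
    # Product rule for sin(x) * y.
    tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
    return primal_out, tangent_out


x, y = 0.5, 2.0
# Reverse-mode gradients are derived from the custom forward-mode rule.
dfdx, dfdy = jax.grad(f, argnums=(0, 1))(x, y)
```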


For a more detailed introduction, see the tutorial.

jax.custom_vjp(fun, nondiff_argnums=())

Set up a JAX-transformable function for a custom VJP rule definition.

This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a reverse-mode differentiation transformation (like jax.grad) is applied, in which case a custom user-supplied VJP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation. There is a single instance method, defvjp, which defines the custom VJP rule.

This decorator precludes the use of forward-mode automatic differentiation.

For example:

import jax
import jax.numpy as np

@jax.custom_vjp
def f(x, y):
  return np.sin(x) * y

def f_fwd(x, y):
  return f(x, y), (np.cos(x), np.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  # Cotangents for (x, y): df/dx = cos(x) * y and df/dy = sin(x)
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)
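A usage sketch for a custom VJP rule: the forward function returns the output together with residuals, and the backward function maps an output cotangent to one cotangent per input. The block is self-contained and checks the result against the analytic derivatives of sin(x) * y; the evaluation point is an arbitrary choice.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x, y):
    return jnp.sin(x) * y


def f_fwd(x, y):
    # Save what the backward pass needs as residuals.
    return f(x, y), (jnp.cos(x), jnp.sin(x), y)


def f_bwd(res, g):
    cos_x, sin_x, y = res
    # One cotangent per input: (df/dx * g, df/dy * g).
    return (cos_x * g * y, sin_x * g)


f.defvjp(f_fwd, f_bwd)

x, y = 0.5, 2.0
dfdx, dfdy = jax.grad(f, argnums=(0, 1))(x, y)
```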


For a more detailed introduction, see the tutorial.

## Vectorization (vmap)

jax.vmap(fun, in_axes=0, out_axes=0)

Vectorizing map. Creates a function which maps fun over argument axes.

Parameters:

- fun (Callable) – Function to be mapped over additional axes.
- in_axes – A nonnegative integer, None, or (nested) standard Python container (tuple/list/dict) thereof specifying which input array axes to map over. If each positional argument to fun is an array, then in_axes can be a nonnegative integer, a None, or a tuple of integers and Nones with length equal to the number of positional arguments to fun. An integer or None indicates which array axis to map over for all arguments (with None indicating not to map any axis), and a tuple indicates which axis to map for each corresponding positional argument. If the positional arguments to fun are container types, the corresponding element of in_axes can itself be a matching container, so that distinct array axes can be mapped for different container elements. in_axes must be a container tree prefix of the positional argument tuple passed to fun.
- out_axes – A nonnegative integer, None, or (nested) standard Python container (tuple/list/dict) thereof indicating where the mapped axis should appear in the output.

Return type: Callable

Returns: Batched/vectorized version of fun with arguments that correspond to those of fun, but with extra array axes at positions indicated by in_axes, and a return value that corresponds to that of fun, but with extra array axes at positions indicated by out_axes.

For example, we can implement a matrix-matrix product using a vector dot product:

>>> vv = lambda x, y: np.vdot(x, y)  #  ([a], [a]) -> []
>>> mv = vmap(vv, (0, None), 0)      #  ([b,a], [a]) -> [b]      (b is the mapped axis)
>>> mm = vmap(mv, (None, 1), 1)      #  ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)


Here we use [a,b] to indicate an array with shape (a,b). Here are some variants:

>>> mv1 = vmap(vv, (0, 0), 0)   #  ([b,a], [b,a]) -> [b]        (b is the mapped axis)
>>> mv2 = vmap(vv, (0, 1), 0)   #  ([b,a], [a,b]) -> [b]        (b is the mapped axis)
>>> mm2 = vmap(mv2, (1, 1), 0)  #  ([b,c,a], [a,c,b]) -> [c,b]  (c is the mapped axis)
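The matrix-matrix product built from vv above can be checked against an ordinary matrix multiplication. This is a small sketch with arbitrarily chosen shapes.

```python
import jax
import jax.numpy as jnp

vv = lambda x, y: jnp.vdot(x, y)      # ([a], [a]) -> []
mv = jax.vmap(vv, (0, None), 0)       # ([b,a], [a]) -> [b]
mm = jax.vmap(mv, (None, 1), 1)       # ([b,a], [a,c]) -> [b,c]

A = jnp.arange(6.0).reshape(2, 3)     # shape [b,a] = (2, 3)
B = jnp.arange(12.0).reshape(3, 4)    # shape [a,c] = (3, 4)

out = mm(A, B)
```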


Here’s an example of using container types in in_axes to specify which axes of the container elements to map over:

>>> A, B, C, D = 2, 3, 4, 5
>>> x = np.ones((A, B))
>>> y = np.ones((B, C))
>>> z = np.ones((C, D))
>>> def foo(tree_arg):
...   x, (y, z) = tree_arg
...   return np.dot(x, np.dot(y, z))
>>> tree = (x, (y, z))
>>> print(foo(tree))
[[12. 12. 12. 12. 12.]
[12. 12. 12. 12. 12.]]
>>> from jax import vmap
>>> K = 6  # batch size
>>> x = np.ones((K, A, B))  # batch axis in different locations
>>> y = np.ones((B, K, C))
>>> z = np.ones((C, D, K))
>>> tree = (x, (y, z))
>>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
>>> print(vfoo(tree).shape)
(6, 2, 5)

jax.numpy.vectorize(pyfunc, *, excluded=frozenset(), signature=None)

Generalized function class.

LAX-backend implementation of vectorize().

JAX’s implementation of vectorize should be considerably more efficient than NumPy’s, because it uses a batching transformation rather than an explicit “for” loop.

Note that JAX only supports the optional excluded (integer only) and signature arguments, both of which must be specified with keywords.
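A minimal sketch of the keyword-only signature argument with the JAX implementation, using a batched matrix-vector product as a hypothetical use case:

```python
import jax.numpy as jnp

# Core dimensions: an (n, m) matrix and an (m,) vector produce an (n,) vector.
matvec = jnp.vectorize(jnp.dot, signature="(n,m),(m)->(n)")

mats = jnp.arange(24.0).reshape(4, 2, 3)  # batch of 4 matrices
vecs = jnp.ones((4, 3))                   # batch of 4 vectors

out = matvec(mats, vecs)
expected = jnp.einsum("bnm,bm->bn", mats, vecs)
```

Leading dimensions not named in the signature are treated as batch dimensions and broadcast against each other.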

Original docstring below.

vectorize(pyfunc, otypes=None, doc=None, excluded=None, cache=False,
signature=None)

Define a vectorized function which takes a nested sequence of objects or numpy arrays as inputs and returns a single numpy array or a tuple of numpy arrays. The vectorized function evaluates pyfunc over successive tuples of the input arrays like the python map function, except it uses the broadcasting rules of numpy.

The data type of the output of vectorized is determined by calling the function with the first element of the input. This can be avoided by specifying the otypes argument.

Returns: vectorized (callable) – Vectorized function.

See also: frompyfunc – Takes an arbitrary Python function and returns a ufunc.

The vectorize function is provided primarily for convenience, not for performance. The implementation is essentially a for loop.

If otypes is not specified, then a call to the function with the first argument will be used to determine the number of outputs. The results of this call will be cached if cache is True to prevent calling the function twice. However, to implement the cache, the original function must be wrapped which will slow down subsequent calls, so only do this if your function is expensive.

The new keyword argument interface and excluded argument support further degrades performance.

>>> def myfunc(a, b):
...     "Return a-b if a>b, otherwise return a+b"
...     if a > b:
...         return a - b
...     else:
...         return a + b

>>> vfunc = np.vectorize(myfunc)
>>> vfunc([1, 2, 3, 4], 2)
array([3, 4, 1, 2])


The docstring is taken from the input function to vectorize unless it is specified:

>>> vfunc.__doc__
'Return a-b if a>b, otherwise return a+b'
>>> vfunc = np.vectorize(myfunc, doc='Vectorized myfunc')
>>> vfunc.__doc__
'Vectorized myfunc'


The output type is determined by evaluating the first element of the input, unless it is specified:

>>> out = vfunc([1, 2, 3, 4], 2)
>>> type(out[0])
<class 'numpy.int64'>
>>> vfunc = np.vectorize(myfunc, otypes=[float])
>>> out = vfunc([1, 2, 3, 4], 2)
>>> type(out[0])
<class 'numpy.float64'>


The excluded argument can be used to prevent vectorizing over certain arguments. This can be useful for array-like arguments of a fixed length such as the coefficients for a polynomial as in polyval:

>>> def mypolyval(p, x):
...     _p = list(p)
...     res = _p.pop(0)
...     while _p:
...         res = res*x + _p.pop(0)
...     return res
>>> vpolyval = np.vectorize(mypolyval, excluded=['p'])
>>> vpolyval(p=[1, 2, 3], x=[0, 1])
array([3, 6])


Positional arguments may also be excluded by specifying their position:

>>> vpolyval.excluded.add(0)
>>> vpolyval([1, 2, 3], x=[0, 1])
array([3, 6])


The signature argument allows for vectorizing functions that act on non-scalar arrays of fixed length. For example, you can use it for a vectorized calculation of Pearson correlation coefficient and its p-value:

>>> import scipy.stats
>>> pearsonr = np.vectorize(scipy.stats.pearsonr,
...                 signature='(n),(n)->(),()')
>>> pearsonr([[0, 1, 2, 3]], [[1, 2, 3, 4], [4, 3, 2, 1]])
(array([ 1., -1.]), array([ 0.,  0.]))


Or for a vectorized convolution:

>>> convolve = np.vectorize(np.convolve, signature='(n),(m)->(k)')
>>> convolve(np.eye(4), [1, 2, 1])
array([[1., 2., 1., 0., 0., 0.],
[0., 1., 2., 1., 0., 0.],
[0., 0., 1., 2., 1., 0.],
[0., 0., 0., 1., 2., 1.]])


## Parallelization (pmap)

jax.pmap(fun, axis_name=None, static_broadcasted_argnums=(), devices=None, backend=None, axis_size=None)

Parallel map with support for collectives.

The purpose of pmap is to express single-program multiple-data (SPMD) programs. Applying pmap to a function will compile the function with XLA (similarly to jit), then execute it in parallel on XLA devices, such as multiple GPUs or multiple TPU cores. Semantically it is comparable to vmap because both transformations map a function over array axes, but where vmap vectorizes functions by pushing the mapped axis down into primitive operations, pmap instead replicates the function and executes each replica on its own XLA device in parallel.

Another key difference with vmap is that while vmap can only express pure maps, pmap enables the use of parallel SPMD collective operations, like all-reduce sum.

The mapped axis size must be less than or equal to the number of local XLA devices available, as returned by jax.local_device_count() (unless devices is specified, see below). For nested pmap calls, the product of the mapped axis sizes must be less than or equal to the number of XLA devices.

Multi-host platforms: On multi-host platforms such as TPU pods, pmap is designed to be used in SPMD Python programs, where every host is running the same Python code such that all hosts run the same pmapped function in the same order. Each host should still call the pmapped function with mapped axis size equal to the number of local devices (unless devices is specified, see below), and an array of the same leading axis size will be returned as usual. However, any collective operations in fun will be computed over all participating devices, including those on other hosts, via device-to-device communication. Conceptually, this can be thought of as running a pmap over a single array sharded across hosts, where each host “sees” only its local shard of the input and output. The SPMD model requires that the same multi-host pmaps must be run in the same order on all devices, but they can be interspersed with arbitrary operations running on a single host.

Parameters:

- fun (Callable) – Function to be mapped over argument axes. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
- axis_name (Optional[Any]) – Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied.
- static_broadcasted_argnums (Union[int, Iterable[int]]) – An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded. Calling the pmapped function with different values for these constants will trigger recompilation. If the pmapped function is called with fewer positional arguments than indicated by static_broadcasted_argnums then an error is raised. Each of the static arguments will be broadcast to all devices. Defaults to ().
- devices – This is an experimental feature and the API is likely to change. Optional, a sequence of Devices to map over. (Available devices can be retrieved via jax.devices().) If specified, the size of the mapped axis must be equal to the number of local devices in the sequence. Nested pmaps with devices specified in either the inner or outer pmap are not yet supported.
- backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: ‘cpu’, ‘gpu’, or ‘tpu’.

Return type: Callable

Returns: A parallelized version of fun with arguments that correspond to those of fun but each with an additional leading array axis (with equal sizes) and with output that has an additional leading array axis (with the same size).

For example, assuming 8 XLA devices are available, pmap can be used as a map along a leading array axis:

>>> out = pmap(lambda x: x ** 2)(np.arange(8))
>>> print(out)
[0, 1, 4, 9, 16, 25, 36, 49]
>>> x = np.arange(3 * 2 * 2.).reshape((3, 2, 2))
>>> y = np.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
>>> out = pmap(np.dot)(x, y)
>>> print(out)
[[[    4.     9.]
  [   12.    29.]]
 [[  244.   345.]
  [  348.   493.]]
 [[ 1412.  1737.]
  [ 1740.  2141.]]]


In addition to expressing pure maps, pmap can also be used to express parallel single-program multiple-data (SPMD) programs that communicate via collective operations. For example:

>>> f = lambda x: x / jax.lax.psum(x, axis_name='i')
>>> out = pmap(f, axis_name='i')(np.arange(4.))
>>> print(out)
[ 0.          0.16666667  0.33333334  0.5       ]
>>> print(out.sum())
1.0


In this example, axis_name is a string, but it can be any Python object with __hash__ and __eq__ defined.
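
The same collective pattern can be written so that it does not depend on a fixed device count. The following is a minimal sketch, assuming only that at least one device is available; the function name center and the input values are illustrative. It uses jax.lax.psum over the named axis to subtract the cross-device mean from each element.

```python
import jax
import jax.numpy as jnp

# Number of devices on this host; the mapped axis is sized to match.
n = jax.local_device_count()

def center(x):
    # psum adds up x across every device participating in axis 'i'.
    total = jax.lax.psum(x, axis_name='i')
    # Dividing by the axis size gives the mean across devices.
    return x - total / n

xs = jnp.arange(n, dtype=jnp.float32) + 1.0
centered = jax.pmap(center, axis_name='i')(xs)

print(centered)
print(centered.sum())  # close to zero, up to floating-point rounding
```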

The argument axis_name to pmap names the mapped axis so that collective operations, like jax.lax.psum, can refer to it. Axis names are important particularly in the case of nested pmap functions, where collectives can operate over distinct axes:

>>> from functools import partial
>>> @partial(pmap, axis_name='rows')
>>> @partial(pmap, axis_name='cols')
>>> def normalize(x):
>>>   row_normed = x / jax.lax.psum(x, 'rows')
>>>   col_normed = x / jax.lax.psum(x, 'cols')
>>>   doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
>>>   return row_normed, col_normed, doubly_normed
>>>
>>> x = np.arange(8.).reshape((4, 2))
>>> row_normed, col_normed, doubly_normed = normalize(x)
>>> print(row_normed.sum(0))
[ 1.  1.]
>>> print(col_normed.sum(1))
[ 1.  1.  1.  1.]
>>> print(doubly_normed.sum((0, 1)))
1.0


On multi-host platforms, collective operations operate over all devices, including those on other hosts. For example, assuming the following code runs on two hosts with 4 XLA devices each:

>>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
>>> data = np.arange(4) if jax.host_id() == 0 else np.arange(4,8)
>>> out = pmap(f, axis_name='i')(data)
>>> print(out)
[28 29 30 31] # on host 0
[32 33 34 35] # on host 1


Each host passes in a different length-4 array, corresponding to its 4 local devices, and the psum operates over all 8 values. Conceptually, the two length-4 arrays can be thought of as a sharded length-8 array (in this example equivalent to np.arange(8)) that is mapped over, with the length-8 mapped axis given name ‘i’. The pmap call on each host then returns the corresponding length-4 output shard.
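
The arithmetic behind this example can be checked on a single machine with plain NumPy. The sketch below is only a simulation of the sharded view described above, not multi-host code: the variable names are illustrative, and the psum is stood in for by a sum over the full conceptual array.

```python
import numpy as np

num_hosts = 2
devices_per_host = 4

# The conceptual global array that is mapped over: np.arange(8).
global_array = np.arange(num_hosts * devices_per_host)

# Each host only sees its own length-4 shard of that array.
shards = np.split(global_array, num_hosts)

# psum over axis 'i' spans all 8 devices, so every device sees the same total.
total = global_array.sum()  # 0 + 1 + ... + 7 = 28

# f(x) = x + psum(x), applied per device; each host gets back its own shard.
outputs = [shard + total for shard in shards]

for host, out in enumerate(outputs):
    print(f"host {host}: {out}")
# host 0: [28 29 30 31]
# host 1: [32 33 34 35]
```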

The devices argument can be used to specify exactly which devices are used to run the parallel computation. For example, again assuming a single host with 8 devices, the following code defines two parallel computations, one which runs on the first six devices and one on the remaining two:

>>> from functools import partial
>>> @partial(pmap, axis_name='i', devices=jax.devices()[:6])
>>> def f1(x):
>>>   return x / jax.lax.psum(x, axis_name='i')
>>>
>>> @partial(pmap, axis_name='i', devices=jax.devices()[-2:])
>>> def f2(x):
>>>   return jax.lax.psum(x ** 2, axis_name='i')
>>>
>>> print(f1(np.arange(6.)))
[0.         0.06666667 0.13333333 0.2        0.26666667 0.33333333]
>>> print(f2(np.array([2., 3.])))
[ 13.  13.]
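
On a machine with fewer than 8 devices, the same idea can be tried by restricting the computation to a slice of whatever devices exist. This is a minimal sketch under that assumption: it pins the computation to the first local device only, and the name fraction_of_total is hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Use only the first local device so the sketch runs on any machine.
subset = jax.local_devices()[:1]

# The mapped axis size must equal the number of devices in `subset`.
@partial(jax.pmap, axis_name='i', devices=subset)
def fraction_of_total(x):
    return x / jax.lax.psum(x, axis_name='i')

out = fraction_of_total(jnp.array([2.0]))
print(out)  # a single element divided by itself: [1.]
```

With more devices available, widening the slice (and the input's leading axis to match) spreads the same computation over that many devices.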

jax.devices(backend=None)[source]

Returns a list of all devices.

Each device is represented by a subclass of Device (e.g. CpuDevice, GpuDevice). The length of the returned list is equal to device_count(). Local devices can be identified by comparing Device.host_id to host_id().

Parameters: backend – This is an experimental feature and the API is likely to change. Optional, a string representing the xla backend. ‘cpu’, ‘gpu’, or ‘tpu’.
Returns: List of Device subclasses.

jax.local_devices(host_id=None, backend=None)[source]

Returns a list of devices local to a given host (this host by default).
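
The relationship between jax.devices() and jax.local_devices() can be inspected directly. The sketch below assumes nothing about the platform; on a single-host machine both lists are expected to contain the same devices, while on a multi-host platform the local list should be a strict subset.

```python
import jax

all_devices = jax.devices()          # devices across all hosts
local_devices = jax.local_devices()  # devices attached to this host

# The list lengths agree with the corresponding count functions.
assert len(all_devices) == jax.device_count()
assert len(local_devices) == jax.local_device_count()

# Every local device also appears in the global list.
all_ids = {d.id for d in all_devices}
local_ids = {d.id for d in local_devices}

print("all devices:  ", all_devices)
print("local devices:", local_devices)
print("local is a subset of all:", local_ids <= all_ids)
```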

jax.host_id(backend=None)[source]

Returns the integer host ID of this host.

On single-host platforms, this is always 0. On multi-host platforms, each host has a different ID.

Parameters: backend – This is an experimental feature and the API is likely to change. Optional, a string representing the xla backend. ‘cpu’, ‘gpu’, or ‘tpu’.
Returns: Integer host ID.

jax.host_ids(backend=None)[source]

Returns a sorted list of all host IDs.

jax.device_count(backend=None)[source]

Returns the total number of devices.

On most platforms, this is the same as local_device_count(). However, on multi-host platforms, this will return the total number of devices across all hosts.

Parameters: backend – This is an experimental feature and the API is likely to change. Optional, a string representing the xla backend. ‘cpu’, ‘gpu’, or ‘tpu’.
Returns: Number of devices.

jax.local_device_count(backend=None)[source]

Returns the number of devices on this host.
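
A common use of local_device_count() is to shape a batch so that its leading axis matches the number of local devices before handing it to pmap. The following is a sketch of that idiom with made-up batch dimensions (4 examples per device, 3 features each); it is one possible way to split the data, not a prescribed one.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
per_device = 4   # illustrative per-device batch size
features = 3     # illustrative feature dimension

# A flat batch of n * per_device examples.
batch = jnp.arange(n * per_device * features, dtype=jnp.float32)
batch = batch.reshape(n * per_device, features)

# Add a leading axis of size n so that pmap places one slice on each device.
sharded = batch.reshape(n, per_device, features)

# Each device sums its own slice over the example axis.
per_device_sums = jax.pmap(lambda x: x.sum(axis=0))(sharded)

print(per_device_sums.shape)  # (n, features)
```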

jax.host_count(backend=None)[source]

Returns the number of hosts.