Public API: jax package
Subpackages
Just-in-time compilation (jit)

jax.jit(fun, static_argnums=(), device=None, backend=None)
Sets up fun for just-in-time compilation with XLA.
Parameters:
- fun (Callable) – Function to be jitted. Should be a pure function, as side-effects may only be executed once. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
  Positional arguments indicated by static_argnums can be anything at all, provided they are hashable and have an equality operation defined. Static arguments are included as part of a compilation cache key, which is why hash and equality operators must be defined.
- static_argnums (Union[int, Iterable[int]]) – An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded. Calling the jitted function with different values for these constants will trigger recompilation. If the jitted function is called with fewer positional arguments than indicated by static_argnums then an error is raised. Defaults to ().
- device – This is an experimental feature and the API is likely to change. Optional, the Device the jitted function will run on. (Available devices can be retrieved via jax.devices().) The default is inherited from XLA’s DeviceAssignment logic and is usually to use jax.devices()[0].
- backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
Returns: A wrapped version of fun, set up for just-in-time compilation.
In the following example, selu can be compiled into a single fused kernel by XLA:
>>> import jax
>>> @jax.jit
... def selu(x, alpha=1.67, lmbda=1.05):
...     return lmbda * jax.numpy.where(x > 0, x, alpha * jax.numpy.exp(x) - alpha)
...
>>> key = jax.random.PRNGKey(0)
>>> x = jax.random.normal(key, (10,))
>>> print(selu(x))
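As a minimal sketch of the static_argnums behavior described above (assuming a recent JAX release), the hypothetical function power below takes a plain Python int n that drives a Python loop. Marking n as static lets the loop unroll at trace time; calling with a different n triggers a fresh compilation rather than an error.

```python
import jax
import jax.numpy as jnp
from functools import partial

# Hypothetical example: n is a plain Python int, so it must be marked static.
@partial(jax.jit, static_argnums=(1,))
def power(x, n):
    result = jnp.ones_like(x)
    for _ in range(n):  # unrolled at trace time because n is a compile-time constant
        result = result * x
    return result

cube = power(jnp.arange(4.0), 3)    # compiles a version specialized to n=3
square = power(jnp.arange(4.0), 2)  # a new value of n triggers recompilation
```

Passing n without static_argnums would make it a tracer, and range(n) would then fail during tracing.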

jax.disable_jit()
Context manager that disables jit behavior under its dynamic context.
For debugging purposes, it is useful to have a mechanism that disables jit everywhere in a dynamic context.
Values that have a data dependence on the arguments to a jitted function are traced and abstracted. For example, an abstract value may be a ShapedArray instance, representing the set of all possible arrays with a given shape and dtype, but not representing one concrete array with specific values. You might notice those if you use a benign side-effecting operation in a jitted function, like a print:
>>> @jax.jit
... def f(x):
...     y = x * 2
...     print("Value of y is", y)
...     return y + 3
...
>>> print(f(jax.numpy.array([1, 2, 3])))
Value of y is Traced<ShapedArray(int32[3]):JaxprTrace(level=1/1)>
[5 7 9]
Here y has been abstracted by jit to a ShapedArray, which represents an array with a fixed shape and type but an arbitrary value. It’s also traced. If we want to see a concrete value while debugging, and avoid the tracer too, we can use the disable_jit context manager:
>>> with jax.disable_jit():
...     print(f(jax.numpy.array([1, 2, 3])))
...
Value of y is [2 4 6]
[5 7 9]
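disable_jit also helps when Python control flow depends on array values, which a tracer cannot provide. The sketch below assumes a recent JAX release; double_if_positive is a hypothetical function that branches with a Python if, so it only runs when jit is disabled.

```python
import jax
import jax.numpy as jnp

# Hypothetical function: branches in Python on the *value* of x.
@jax.jit
def double_if_positive(x):
    if x > 0:  # needs a concrete value; a jit tracer cannot be converted to bool
        return x * 2.0
    return x * 0.0

# With jit disabled, the function runs eagerly on concrete arrays.
with jax.disable_jit():
    pos = double_if_positive(jnp.float32(3.0))
    neg = double_if_positive(jnp.float32(-1.0))

# Under jit, the same call fails while tracing the Python `if`.
try:
    double_if_positive(jnp.float32(3.0))
    jit_call_failed = False
except Exception:
    jit_call_failed = True
```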

jax.xla_computation(fun, static_argnums=(), axis_env=None, backend=None, tuple_args=False, instantiate_const_outputs=True)
Creates a function that produces its XLA computation given example args.
Parameters:
- fun (Callable) – Function from which to form XLA computations.
- static_argnums (Union[int, Iterable[int]]) – See the jax.jit docstring.
- axis_env (Optional[Sequence[Tuple[Any, int]]]) – Optional, a sequence of pairs where the first element is an axis name and the second element is a positive integer representing the size of the mapped axis with that name. This parameter is useful when lowering functions that involve parallel communication collectives, and it specifies the axis name/size environment that would be set up by applications of jax.pmap. See the examples below.
- backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
- tuple_args (bool) – Optional bool, defaults to False. If True, the resulting XLA computation will have a single tuple argument that is unpacked into the specified function arguments.
- instantiate_const_outputs (bool) – Optional bool, defaults to True. If False, then xla_computation does not instantiate constant-valued outputs in the XLA computation, and so the result is closer to the computation that jax.jit produces and may be more useful for studying jit behavior. If True, then constant-valued outputs are instantiated in the XLA computation, which may be more useful for staging computations out of JAX entirely.
Returns: A wrapped version of fun that, when applied to example arguments, returns a built XLA Computation (see xla_client.py), from which representations of the unoptimized XLA HLO computation can be extracted using methods like GetHloText, GetSerializedProto, and GetHloDotGraph.
For example:
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> c = jax.xla_computation(f)(3.)
>>> print(c.GetHloText())
HloModule jaxpr_computation__4.5

ENTRY jaxpr_computation__4.5 {
  tuple.1 = () tuple()
  parameter.2 = f32[] parameter(0)
  cosine.3 = f32[] cosine(parameter.2)
  ROOT sine.4 = f32[] sine(cosine.3)
}
Here’s an example that involves a parallel collective and axis name:
>>> def f(x): return x - jax.lax.psum(x, 'i')
>>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2)
>>> print(c.GetHloText())
HloModule jaxpr_computation.9

primitive_computation.3 {
  parameter.4 = s32[] parameter(0)
  parameter.5 = s32[] parameter(1)
  ROOT add.6 = s32[] add(parameter.4, parameter.5)
}

ENTRY jaxpr_computation.9 {
  tuple.1 = () tuple()
  parameter.2 = s32[] parameter(0)
  all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3
  ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7)
}
Notice the replica_groups that were generated. Here’s an example that generates more interesting replica_groups:
>>> def g(x):
...     rowsum = jax.lax.psum(x, 'i')
...     colsum = jax.lax.psum(x, 'j')
...     allsum = jax.lax.psum(x, ('i', 'j'))
...     return rowsum, colsum, allsum
...
>>> axis_env = [('i', 4), ('j', 2)]
>>> c = jax.xla_computation(g, axis_env=axis_env)(5.)
>>> print(c.GetHloText())
HloModule jaxpr_computation__1.19
[removed uninteresting text here]
ENTRY jaxpr_computation__1.19 {
  tuple.1 = () tuple()
  parameter.2 = f32[] parameter(0)
  all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3
  all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8
  all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13
  ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17)
}
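The replica_groups above can be reproduced by hand. The sketch below assumes (consistent with the HLO shown, but not stated by the API) that devices are numbered row-major over the axes in axis_env order. replica_groups is a hypothetical helper that groups together the device ids a psum over the given axis names combines; it uses plain NumPy and does not call xla_computation.

```python
import numpy as np

def replica_groups(axis_env, axes):
    """Hypothetical helper: device-id groups that a psum over `axes` reduces.

    Assumes devices are numbered row-major over the axes in axis_env order.
    """
    names = [name for name, _ in axis_env]
    sizes = [size for _, size in axis_env]
    ids = np.arange(int(np.prod(sizes))).reshape(sizes)
    reduced = [names.index(a) for a in axes]
    kept = [i for i in range(len(names)) if i not in reduced]
    # Move the reduced axes last so that each row holds one reduction group.
    grouped = ids.transpose(kept + reduced)
    group_size = int(np.prod([sizes[i] for i in reduced]))
    return grouped.reshape(-1, group_size).tolist()

axis_env = [('i', 4), ('j', 2)]
row_groups = replica_groups(axis_env, ('i',))       # matches all-reduce.7
col_groups = replica_groups(axis_env, ('j',))       # matches all-reduce.12
all_groups = replica_groups(axis_env, ('i', 'j'))   # matches all-reduce.17
```

Reducing over 'i' groups devices that share the same 'j' coordinate, which is why the stride-2 groups {0,2,4,6} and {1,3,5,7} appear.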

jax.make_jaxpr(fun)
Creates a function that produces its jaxpr given example args.
Parameters: fun (Callable) – The function whose jaxpr is to be computed. Its positional arguments and return value should be arrays, scalars, or standard Python containers (tuple/list/dict) thereof.
Return type: Callable[..., TypedJaxpr]
Returns: A wrapped version of fun that, when applied to example arguments, returns a TypedJaxpr representation of fun on those arguments.
A jaxpr is JAX’s intermediate representation for program traces. The jaxpr language is based on the simply-typed first-order lambda calculus with let-bindings. make_jaxpr adapts a function to return its jaxpr, which we can inspect to understand what JAX is doing internally.
The jaxpr returned is a trace of fun abstracted to ShapedArray level. Other levels of abstraction exist internally.
We do not describe the semantics of the jaxpr language in detail here, but instead give a few examples.
>>> def f(x): return jax.numpy.sin(jax.numpy.cos(x))
>>> print(f(3.0))
-0.83602184
>>> jax.make_jaxpr(f)(3.0)
{ lambda ; ; a.
  let b = cos a
      c = sin b
  in [c] }
>>> jax.make_jaxpr(jax.grad(f))(3.0)
{ lambda ; ; a.
  let b = cos a
      c = cos b
      d = mul 1.0 c
      e = neg d
      f = sin a
      g = mul e f
  in [g] }
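Because the result of make_jaxpr is a data structure rather than only a printout, its equations can also be inspected programmatically. A minimal sketch, assuming a recent JAX release where the returned object exposes .jaxpr.eqns and each equation records its primitive:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(jnp.cos(x))

closed = jax.make_jaxpr(f)(3.0)
# Each equation in the jaxpr records one primitive application.
prim_names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
print(closed)
print(prim_names)
```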

jax.eval_shape(fun, *args, **kwargs)
Compute the shape/dtype of fun(*args, **kwargs) without any FLOPs.
This utility function is useful for performing shape inference. Its input/output behavior is defined by:
def eval_shape(fun, *args, **kwargs):
    out = fun(*args, **kwargs)
    return jax.tree_util.tree_map(shape_dtype_struct, out)

def shape_dtype_struct(x):
    return ShapeDtypeStruct(x.shape, x.dtype)

class ShapeDtypeStruct(object):
    __slots__ = ["shape", "dtype"]
    def __init__(self, shape, dtype):
        self.shape = shape
        self.dtype = dtype
In particular, the output is a pytree of objects that have shape and dtype attributes, but nothing else about them is guaranteed by the API.
But instead of applying fun directly, which might be expensive, it uses JAX’s abstract interpretation machinery to evaluate the shapes without doing any FLOPs.
Using eval_shape can also catch shape errors, and will raise the same shape errors as evaluating fun(*args, **kwargs).
Parameters:
- *args – a positional argument tuple of arrays, scalars, or (nested) standard Python containers (tuples, lists, dicts, namedtuples, i.e. pytrees) of those types. Since only the shape and dtype attributes are accessed, only values that duck-type arrays are required, rather than real ndarrays. The duck-typed objects cannot be namedtuples because those are treated as standard Python containers. See the example below.
- **kwargs – a keyword argument dict of arrays, scalars, or (nested) standard Python containers (pytrees) of those types. As in args, array values need only be duck-typed to have shape and dtype attributes.
For example:
>>> import jax.numpy as np
>>> f = lambda A, x: np.tanh(np.dot(A, x))
>>> class MyArgArray(object):
...     def __init__(self, shape, dtype):
...         self.shape = shape
...         self.dtype = dtype
...
>>> A = MyArgArray((2000, 3000), np.float32)
>>> x = MyArgArray((3000, 1000), np.float32)
>>> out = jax.eval_shape(f, A, x)  # no FLOPs performed
>>> print(out.shape)
(2000, 1000)
>>> print(out.dtype)
float32
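The same shape inference can be sketched with jax.ShapeDtypeStruct (assumed available, as in recent JAX releases) instead of a hand-written duck-typed class. The second call shows eval_shape surfacing a shape error without allocating or computing anything; bad_x is a hypothetical mismatched argument.

```python
import jax
import jax.numpy as jnp

def f(A, x):
    return jnp.tanh(jnp.dot(A, x))

# jax.ShapeDtypeStruct carries only shape and dtype; no data is allocated.
A = jax.ShapeDtypeStruct((2000, 3000), jnp.float32)
x = jax.ShapeDtypeStruct((3000, 1000), jnp.float32)
out = jax.eval_shape(f, A, x)

# Mismatched inner dimensions are reported during abstract evaluation.
bad_x = jax.ShapeDtypeStruct((2000, 1000), jnp.float32)
try:
    jax.eval_shape(f, A, bad_x)
    shape_error_raised = False
except TypeError:
    shape_error_raised = True
```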
Automatic differentiation

jax.grad(fun, argnums=0, has_aux=False, holomorphic=False)
Creates a function which evaluates the gradient of fun.
Parameters:
- fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
- argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
- has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
- holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
Returns: A function with the same arguments as fun, that evaluates the gradient of fun. If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a tuple of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments. If has_aux is True then a pair of (gradient, auxiliary_data) is returned.
For example:
>>> grad_tanh = jax.grad(jax.numpy.tanh)
>>> print(grad_tanh(0.2))
0.961043
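The argnums and has_aux options can be combined. In this minimal sketch, loss is a hypothetical squared-error function that also returns its residuals as auxiliary data; differentiating with respect to arguments 0 and 1 yields a tuple of gradients, paired with the untouched aux value.

```python
import jax
import jax.numpy as jnp

def loss(w, b, x, y):
    # Hypothetical squared-error loss; returns (scalar loss, auxiliary residuals).
    err = w * x + b - y
    return jnp.mean(err ** 2), err

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])
# argnums=(0, 1) differentiates w.r.t. w and b; has_aux=True passes err through.
(dw, db), err = jax.grad(loss, argnums=(0, 1), has_aux=True)(2.0, 0.5, x, y)
```

Here every residual is 0.5, so dL/dw = mean(2 * 0.5 * x) = 2.0 and dL/db = mean(2 * 0.5) = 1.0.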

jax.value_and_grad(fun, argnums=0, has_aux=False, holomorphic=False)
Creates a function which evaluates both fun and the gradient of fun.
Parameters:
- fun (Callable) – Function to be differentiated. Its arguments at positions specified by argnums should be arrays, scalars, or standard Python containers. It should return a scalar (which includes arrays with shape () but not arrays with shape (1,) etc.)
- argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
- has_aux (bool) – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
- holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
Returns: A function with the same arguments as fun that evaluates both fun and the gradient of fun and returns them as a pair (a two-element tuple). If argnums is an integer then the gradient has the same shape and type as the positional argument indicated by that integer. If argnums is a sequence of integers, the gradient is a tuple of values with the same shapes and types as the corresponding arguments.
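A minimal usage sketch, with a hypothetical f: one call returns both the value and the gradient, which should agree with evaluating f directly and with the analytic derivative d/dx sin²(x) = sin(2x).

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

x = jnp.array([0.1, 0.5, 1.0])
value, g = jax.value_and_grad(f)(x)  # (f(x), grad f(x)) from a single call
```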

jax.jacfwd(fun, argnums=0, holomorphic=False)
Jacobian of fun evaluated column-by-column using forward-mode AD.
Parameters:
- fun (Callable) – Function whose Jacobian is to be computed.
- argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
- holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
Returns: A function with the same arguments as fun, that evaluates the Jacobian of fun using forward-mode automatic differentiation.
>>> import jax.numpy as np
>>> def f(x):
...     return jax.numpy.asarray(
...         [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jax.numpy.sin(x[0])])
...
>>> print(jax.jacfwd(f)(np.array([1., 2., 3.])))
[[ 1.        ,  0.        ,  0.        ],
 [ 0.        ,  0.        ,  5.        ],
 [ 0.        , 16.        , -2.        ],
 [ 1.6209068 ,  0.        ,  0.84147096]]
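Since jacfwd builds the Jacobian one column per input, it suits "tall" functions with more outputs than inputs. A minimal sketch with a hypothetical f mapping R^2 to R^3, checked against the hand-derived Jacobian at x = [1, 2]:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical R^2 -> R^3 function: a "tall" Jacobian with 2 columns.
    return jnp.stack([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])

x = jnp.array([1.0, 2.0])
J = jax.jacfwd(f)(x)  # shape (3, 2): one row per output, one column per input

# Hand-derived Jacobian at x = [1, 2].
expected = jnp.array([[2.0, 1.0],
                      [jnp.cos(1.0), 0.0],
                      [0.0, 4.0]])
```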

jax.jacrev(fun, argnums=0, holomorphic=False)
Jacobian of fun evaluated row-by-row using reverse-mode AD.
Parameters:
- fun (Callable) – Function whose Jacobian is to be computed.
- argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
- holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
Returns: A function with the same arguments as fun, that evaluates the Jacobian of fun using reverse-mode automatic differentiation.
>>> import jax.numpy as np
>>> def f(x):
...     return jax.numpy.asarray(
...         [x[0], 5*x[2], 4*x[1]**2 - 2*x[2], x[2] * jax.numpy.sin(x[0])])
...
>>> print(jax.jacrev(f)(np.array([1., 2., 3.])))
[[ 1.        ,  0.        ,  0.        ],
 [ 0.        ,  0.        ,  5.        ],
 [ 0.        , 16.        , -2.        ],
 [ 1.6209068 ,  0.        ,  0.84147096]]
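Reverse mode builds the Jacobian one row per output, so it suits "wide" functions with fewer outputs than inputs. In this sketch, g is a hypothetical R^3 to R^2 function; jacrev and jacfwd should agree up to floating-point error and differ only in cost.

```python
import jax
import jax.numpy as jnp

def g(x):
    # Hypothetical R^3 -> R^2 function: a "wide" Jacobian with 2 rows.
    return jnp.stack([jnp.sum(x ** 2), x[0] * x[2]])

x = jnp.array([1.0, 2.0, 3.0])
J_rev = jax.jacrev(g)(x)  # built row by row (one reverse pass per output)
J_fwd = jax.jacfwd(g)(x)  # built column by column (one forward pass per input)

expected = jnp.array([[2.0, 4.0, 6.0],
                      [3.0, 0.0, 1.0]])
```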

jax.hessian(fun, argnums=0, holomorphic=False)
Hessian of fun.
Parameters:
- fun (Callable) – Function whose Hessian is to be computed.
- argnums (Union[int, Sequence[int]]) – Optional, integer or sequence of integers. Specifies which positional argument(s) to differentiate with respect to (default 0).
- holomorphic (bool) – Optional, bool. Indicates whether fun is promised to be holomorphic. Default False.
Returns: A function with the same arguments as fun, that evaluates the Hessian of fun.
>>> g = lambda x: x[0]**3 - 2*x[0]*x[1] - x[1]**6
>>> print(jax.hessian(g)(jax.numpy.array([1., 2.])))
[[   6.,   -2.],
 [  -2., -480.]]
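A quadratic form gives an easy check: for q(x) = xᵀAx the Hessian is the constant matrix A + Aᵀ. This sketch uses a hypothetical A and also compares against composing jacfwd over jacrev, which is assumed here to compute the same second derivatives.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 2.0],
               [0.0, 3.0]])

def quad(x):
    # Hypothetical quadratic form x^T A x; its Hessian is A + A^T everywhere.
    return x @ A @ x

x = jnp.array([0.5, -1.0])
H = jax.hessian(quad)(x)
H_composed = jax.jacfwd(jax.jacrev(quad))(x)
```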

jax.jvp(fun, primals, tangents)
Computes a (forward-mode) Jacobian-vector product of fun.
Parameters:
- fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
- primals – The primal values at which the Jacobian of fun should be evaluated. Should be either a tuple or a list of arguments, and its length should be equal to the number of positional parameters of fun.
- tangents – The tangent vector for which the Jacobian-vector product should be evaluated. Should be either a tuple or a list of tangents, with the same tree structure and array shapes as primals.
Returns: A (primals_out, tangents_out) pair, where primals_out is fun(*primals), and tangents_out is the Jacobian-vector product of fun evaluated at primals with tangents. The tangents_out value has the same Python tree structure and shapes as primals_out.
For example:
>>> y, v = jax.jvp(jax.numpy.sin, (0.1,), (0.2,))
>>> print(y)
0.09983342
>>> print(v)
0.19900084
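For a vector-valued function, the tangent output of jvp is the Jacobian applied to the input tangent, J(x) v, computed without materializing J. This sketch (hypothetical f) checks that identity against an explicit jacfwd.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical R^2 -> R^2 function.
    return jnp.stack([x[0] ** 2, x[0] * x[1]])

x = jnp.array([3.0, 4.0])
v = jnp.array([1.0, 0.5])
y, jv = jax.jvp(f, (x,), (v,))  # jv = J(x) @ v, without forming J

J = jax.jacfwd(f)(x)  # explicit Jacobian, only for comparison
```

At x = [3, 4], J = [[6, 0], [4, 3]], so J v = [6, 5.5].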

jax.linearize(fun, *primals)
Produce a linear approximation to fun using jvp and partial evaluation.
Parameters:
- fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
- primals – The primal values at which the Jacobian of fun should be evaluated. Should be a tuple of arrays, scalars, or standard Python containers thereof. The length of the tuple is equal to the number of positional parameters of fun.
Returns: A pair where the first element is the value of f(*primals) and the second element is a function that evaluates the (forward-mode) Jacobian-vector product of fun evaluated at primals without redoing the linearization work.
In terms of values computed, linearize behaves much like a curried jvp, where these two code blocks compute the same values:
y, out_tangent = jax.jvp(f, (x,), (in_tangent,))

y, f_jvp = jax.linearize(f, x)
out_tangent = f_jvp(in_tangent)
However, the difference is that linearize uses partial evaluation so that the function f is not re-linearized on calls to f_jvp. In general that means the memory usage scales with the size of the computation, much like in reverse-mode. (Indeed, linearize has a similar signature to vjp!)
This function is mainly useful if you want to apply f_jvp multiple times, i.e. to evaluate a pushforward for many different input tangent vectors at the same linearization point. Moreover if all the input tangent vectors are known at once, it can be more efficient to vectorize using vmap, as in:
pushfwd = partial(jvp, f, (x,))
y, out_tangents = vmap(pushfwd, out_axes=(None, 0))((in_tangents,))
By using vmap and jvp together like this we avoid the stored-linearization memory cost that scales with the depth of the computation, which is incurred by both linearize and vjp.
Here’s a more complete example of using linearize:
>>> import jax.numpy as np
>>> def f(x): return 3. * np.sin(x) + np.cos(x / 2.)
...
>>> jax.jvp(f, (2.,), (3.,))
(array(3.2681944, dtype=float32), array(-5.007528, dtype=float32))
>>> y, f_jvp = jax.linearize(f, 2.)
>>> print(y)
3.2681944
>>> print(f_jvp(3.))
-5.007528
>>> print(f_jvp(4.))
-6.676704
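The two strategies described above (reusing f_jvp versus batching jvp with vmap) can be compared directly. A minimal sketch, assuming a recent JAX release; the list of tangents is hypothetical, and x is a strongly typed float32 so primal and tangent dtypes match exactly.

```python
import jax
import jax.numpy as jnp
from functools import partial

def f(x):
    return 3.0 * jnp.sin(x) + jnp.cos(x / 2.0)

x = jnp.float32(2.0)
tangents = jnp.array([3.0, 4.0, -1.0], dtype=jnp.float32)

# Option 1: linearize once, then reuse f_jvp for each tangent.
y, f_jvp = jax.linearize(f, x)
out_linearize = jnp.stack([f_jvp(t) for t in tangents])

# Option 2: push all tangents through jvp at once with vmap.
pushfwd = partial(jax.jvp, f, (x,))
y_vmap, out_vmap = jax.vmap(lambda t: pushfwd((t,)), out_axes=(None, 0))(tangents)
```

Option 2 avoids storing the linearization, at the cost of needing all tangents up front.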

jax.vjp(fun, *primals, **kwargs)
Compute a (reverse-mode) vector-Jacobian product of fun.
grad() is implemented as a special case of vjp().
Parameters:
- fun (Callable) – Function to be differentiated. Its arguments should be arrays, scalars, or standard Python containers of arrays or scalars. It should return an array, scalar, or standard Python container of arrays or scalars.
- primals – A sequence of primal values at which the Jacobian of fun should be evaluated. The length of primals should be equal to the number of positional parameters to fun. Each primal value should be a tuple of arrays, scalars, or standard Python containers thereof.
- has_aux – Optional, bool. Indicates whether fun returns a pair where the first element is considered the output of the mathematical function to be differentiated and the second element is auxiliary data. Default False.
Returns: If has_aux is False, returns a (primals_out, vjpfun) pair, where primals_out is fun(*primals). vjpfun is a function from a cotangent vector with the same shape as primals_out to a tuple of cotangent vectors with the same shape as primals, representing the vector-Jacobian product of fun evaluated at primals. If has_aux is True, returns a (primals_out, vjpfun, aux) tuple where aux is the auxiliary data returned by fun.
>>> def f(x, y):
...     return jax.numpy.sin(x), jax.numpy.cos(y)
...
>>> primals, f_vjp = jax.vjp(f, 0.5, 1.0)
>>> xbar, ybar = f_vjp((0.7, 0.3))
>>> print(xbar)
0.61430776
>>> print(ybar)
-0.2524413
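For a single array argument, the cotangent returned by vjpfun is the transposed Jacobian applied to the output cotangent, Jᵀu. This sketch (hypothetical f and u) checks that against an explicit jacrev and a hand-derived value.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical R^2 -> R^2 function.
    return jnp.stack([x[0] * x[1], jnp.sin(x[0])])

x = jnp.array([0.5, 2.0])
y, f_vjp = jax.vjp(f, x)
u = jnp.array([1.0, -3.0])   # cotangent with the same shape as y
(xbar,) = f_vjp(u)           # one cotangent per primal argument

J = jax.jacrev(f)(x)  # explicit Jacobian, only for comparison
```

Here J = [[2, 0.5], [cos 0.5, 0]], so Jᵀu = [2 - 3 cos 0.5, 0.5].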

jax.custom_jvp(fun, nondiff_argnums=())
Set up a JAX-transformable function for a custom JVP rule definition.
This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a differentiation transformation (like jax.jvp or jax.grad) is applied, in which case a custom user-supplied JVP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation. There is a single instance method, defjvp, which defines the custom JVP rule.
For example:
import jax
import jax.numpy as np

@jax.custom_jvp
def f(x, y):
    return np.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
    x, y = primals
    x_dot, y_dot = tangents
    primal_out = f(x, y)
    # Product rule: d(sin(x) * y) = cos(x) * y * dx + sin(x) * dy
    tangent_out = np.cos(x) * x_dot * y + np.sin(x) * y_dot
    return primal_out, tangent_out
For a more detailed introduction, see the tutorial.
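To see that the custom rule really replaces automatic differentiation, the sketch below defines a rule that deliberately differs from the true derivative: a hypothetical round_ste treats rounding as the identity when differentiating (a "straight-through" estimator), whereas plain jnp.round has a zero derivative. This is an illustration of the mechanism, not a recommended rule.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def round_ste(x):
    return jnp.round(x)

@round_ste.defjvp
def round_ste_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # Pretend round() is the identity for differentiation purposes.
    return jnp.round(x), x_dot

value = round_ste(1.3)
g_custom = jax.grad(lambda x: round_ste(x) * 3.0)(1.3)  # uses the custom rule
g_plain = jax.grad(lambda x: jnp.round(x) * 3.0)(1.3)   # built-in rule: zero
```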

jax.custom_vjp(fun, nondiff_argnums=())
Set up a JAX-transformable function for a custom VJP rule definition.
This class is meant to be used as a function decorator. Instances are callables that behave similarly to the underlying function to which the decorator was applied, except when a reverse-mode differentiation transformation (like jax.grad) is applied, in which case a custom user-supplied VJP rule function is used instead of tracing into and performing automatic differentiation of the underlying function’s implementation. There is a single instance method, defvjp, which defines the custom VJP rule.
This decorator precludes the use of forward-mode automatic differentiation.
For example:
import jax
import jax.numpy as np

@jax.custom_vjp
def f(x, y):
    return np.sin(x) * y

def f_fwd(x, y):
    # Return the primal output plus residuals saved for the backward pass.
    return f(x, y), (np.cos(x), np.sin(x), y)

def f_bwd(res, g):
    cos_x, sin_x, y = res
    return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd)
For a more detailed introduction, see the tutorial.
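A common use of custom_vjp is to modify the backward pass while leaving the forward pass untouched. In this minimal sketch, clip_gradient is a hypothetical identity function whose VJP clips the incoming cotangent to [-1, 1]; the last part checks the note above that forward-mode differentiation is rejected.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_gradient(x):
    return x  # identity on the forward pass

def clip_fwd(x):
    return x, None  # no residuals needed

def clip_bwd(res, g):
    # One cotangent per primal input, returned as a tuple.
    return (jnp.clip(g, -1.0, 1.0),)

clip_gradient.defvjp(clip_fwd, clip_bwd)

g_clipped = jax.grad(lambda x: 5.0 * clip_gradient(x))(2.0)  # cotangent 5.0 clipped to 1.0
g_plain = jax.grad(lambda x: 5.0 * x)(2.0)

try:
    jax.jvp(clip_gradient, (2.0,), (1.0,))
    forward_mode_rejected = False
except Exception:
    forward_mode_rejected = True
```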
Vectorization (vmap)

jax.vmap(fun, in_axes=0, out_axes=0)
Vectorizing map. Creates a function which maps fun over argument axes.
Parameters:
- fun (Callable) – Function to be mapped over additional axes.
- in_axes – A non-negative integer, None, or (nested) standard Python container (tuple/list/dict) thereof specifying which input array axes to map over. If each positional argument to fun is an array, then in_axes can be a non-negative integer, a None, or a tuple of integers and Nones with length equal to the number of positional arguments to fun. An integer or None indicates which array axis to map over for all arguments (with None indicating not to map any axis), and a tuple indicates which axis to map for each corresponding positional argument. If the positional arguments to fun are container types, the corresponding element of in_axes can itself be a matching container, so that distinct array axes can be mapped for different container elements. in_axes must be a container tree prefix of the positional argument tuple passed to fun.
- out_axes – A non-negative integer, None, or (nested) standard Python container (tuple/list/dict) thereof indicating where the mapped axis should appear in the output.
Returns: Batched/vectorized version of fun with arguments that correspond to those of fun, but with extra array axes at positions indicated by in_axes, and a return value that corresponds to that of fun, but with extra array axes at positions indicated by out_axes.
For example, we can implement a matrix-matrix product using a vector dot product:
>>> vv = lambda x, y: np.vdot(x, y)  # ([a], [a]) -> []
>>> mv = vmap(vv, (0, None), 0)      # ([b,a], [a]) -> [b]      (b is the mapped axis)
>>> mm = vmap(mv, (None, 1), 1)      # ([b,a], [a,c]) -> [b,c]  (c is the mapped axis)
Here we use [a,b] to indicate an array with shape (a,b). Here are some variants:
>>> mv1 = vmap(vv, (0, 0), 0)   # ([b,a], [b,a]) -> [b]        (b is the mapped axis)
>>> mv2 = vmap(vv, (0, 1), 0)   # ([b,a], [a,b]) -> [b]        (b is the mapped axis)
>>> mm2 = vmap(mv2, (1, 1), 0)  # ([b,c,a], [a,c,b]) -> [c,b]  (c is the mapped axis)
Here’s an example of using container types in in_axes to specify which axes of the container elements to map over:
>>> A, B, C, D = 2, 3, 4, 5
>>> x = np.ones((A, B))
>>> y = np.ones((B, C))
>>> z = np.ones((C, D))
>>> def foo(tree_arg):
...     x, (y, z) = tree_arg
...     return np.dot(x, np.dot(y, z))
>>> tree = (x, (y, z))
>>> print(foo(tree))
[[12. 12. 12. 12. 12.]
 [12. 12. 12. 12. 12.]]
>>> from jax import vmap
>>> K = 6  # batch size
>>> x = np.ones((K, A, B))  # batch axis in different locations
>>> y = np.ones((B, K, C))
>>> z = np.ones((C, D, K))
>>> tree = (x, (y, z))
>>> vfoo = vmap(foo, in_axes=((0, (1, 2)),))
>>> print(vfoo(tree).shape)
(6, 2, 5)
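The matrix-matrix construction from the first example can be run end to end and checked against the @ operator. A minimal sketch with hypothetical A and B; the comments follow the [shape] notation used above.

```python
import jax
import jax.numpy as jnp

def vv(x, y):
    return jnp.vdot(x, y)  # ([a], [a]) -> []

mv = jax.vmap(vv, (0, None), 0)  # ([b,a], [a]) -> [b]
mm = jax.vmap(mv, (None, 1), 1)  # ([b,a], [a,c]) -> [b,c]

A = jnp.arange(6.0).reshape((2, 3))
B = jnp.arange(12.0).reshape((3, 4))
C = mm(A, B)
```

Mapping over axis 1 of B (the columns) and emitting the mapped axis at position 1 of the output places each column's result in the matching output column.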

jax.numpy.vectorize(pyfunc, *, excluded=frozenset(), signature=None)
Generalized function class.
LAX-backend implementation of vectorize().
JAX’s implementation of vectorize should be considerably more efficient than NumPy’s, because it uses a batching transformation rather than an explicit “for” loop.
Note that JAX only supports the optional excluded (integer only) and signature arguments, both of which must be specified with keywords.
Original docstring below.
vectorize(pyfunc, otypes=None, doc=None, excluded=None, cache=False, signature=None)
Define a vectorized function which takes a nested sequence of objects or numpy arrays as inputs and returns a single numpy array or a tuple of numpy arrays. The vectorized function evaluates pyfunc over successive tuples of the input arrays like the python map function, except it uses the broadcasting rules of numpy.
The data type of the output of vectorized is determined by calling the function with the first element of the input. This can be avoided by specifying the otypes argument.
Returns
vectorized : callable
    Vectorized function.
See Also
frompyfunc : Takes an arbitrary Python function and returns a ufunc
Notes
The vectorize function is provided primarily for convenience, not for performance. The implementation is essentially a for loop.
If otypes is not specified, then a call to the function with the first argument will be used to determine the number of outputs. The results of this call will be cached if cache is True to prevent calling the function twice. However, to implement the cache, the original function must be wrapped which will slow down subsequent calls, so only do this if your function is expensive.
The new keyword argument interface and excluded argument support further degrades performance.
Examples
>>> def myfunc(a, b):
...     "Return a-b if a>b, otherwise return a+b"
...     if a > b:
...         return a - b
...     else:
...         return a + b
>>> vfunc = np.vectorize(myfunc)
>>> vfunc([1, 2, 3, 4], 2)
array([3, 4, 1, 2])
The docstring is taken from the input function to vectorize unless it is specified:
>>> vfunc.__doc__
'Return a-b if a>b, otherwise return a+b'
>>> vfunc = np.vectorize(myfunc, doc='Vectorized `myfunc`')
>>> vfunc.__doc__
'Vectorized `myfunc`'
The output type is determined by evaluating the first element of the input, unless it is specified:
>>> out = vfunc([1, 2, 3, 4], 2)
>>> type(out[0])
<class 'numpy.int64'>
>>> vfunc = np.vectorize(myfunc, otypes=[float])
>>> out = vfunc([1, 2, 3, 4], 2)
>>> type(out[0])
<class 'numpy.float64'>
The excluded argument can be used to prevent vectorizing over certain arguments. This can be useful for array-like arguments of a fixed length such as the coefficients for a polynomial as in polyval:
>>> def mypolyval(p, x):
...     _p = list(p)
...     res = _p.pop(0)
...     while _p:
...         res = res*x + _p.pop(0)
...     return res
>>> vpolyval = np.vectorize(mypolyval, excluded=['p'])
>>> vpolyval(p=[1, 2, 3], x=[0, 1])
array([3, 6])
Positional arguments may also be excluded by specifying their position:
>>> vpolyval.excluded.add(0)
>>> vpolyval([1, 2, 3], x=[0, 1])
array([3, 6])
The signature argument allows for vectorizing functions that act on non-scalar arrays of fixed length. For example, you can use it for a vectorized calculation of Pearson correlation coefficient and its p-value:
>>> import scipy.stats
>>> pearsonr = np.vectorize(scipy.stats.pearsonr,
...                         signature='(n),(n)->(),()')
>>> pearsonr([[0, 1, 2, 3]], [[1, 2, 3, 4], [4, 3, 2, 1]])
(array([ 1., -1.]), array([ 0.,  0.]))
Or for a vectorized convolution:
>>> convolve = np.vectorize(np.convolve, signature='(n),(m)->(k)')
>>> convolve(np.eye(4), [1, 2, 1])
array([[1., 2., 1., 0., 0., 0.],
       [0., 1., 2., 1., 0., 0.],
       [0., 0., 1., 2., 1., 0.],
       [0., 0., 0., 1., 2., 1.]])
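The JAX version accepts the same signature strings. In this minimal sketch (assuming a recent JAX release), the hypothetical matvec has core signature (n,m),(m)->(n); jax.numpy.vectorize batches it over the extra leading axis of A and broadcasts x across it.

```python
import jax.numpy as jnp

def matvec(A, x):
    # Core operation on a single (n, m) matrix and (m,) vector.
    return A @ x

vmatvec = jnp.vectorize(matvec, signature='(n,m),(m)->(n)')

A = jnp.ones((5, 3, 4))  # batch of 5 matrices, each 3x4
x = jnp.arange(4.0)      # one vector, broadcast across the batch
out = vmatvec(A, x)      # shape (5, 3); every entry is 0+1+2+3 = 6
```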
Parallelization (pmap)

jax.pmap(fun, axis_name=None, static_broadcasted_argnums=(), devices=None, backend=None, axis_size=None)
Parallel map with support for collectives.
The purpose of pmap is to express single-program multiple-data (SPMD) programs. Applying pmap to a function will compile the function with XLA (similarly to jit), then execute it in parallel on XLA devices, such as multiple GPUs or multiple TPU cores. Semantically it is comparable to vmap because both transformations map a function over array axes, but where vmap vectorizes functions by pushing the mapped axis down into primitive operations, pmap instead replicates the function and executes each replica on its own XLA device in parallel.
Another key difference with vmap is that while vmap can only express pure maps, pmap enables the use of parallel SPMD collective operations, like all-reduce sum.
The mapped axis size must be less than or equal to the number of local XLA devices available, as returned by jax.local_device_count() (unless devices is specified, see below). For nested pmap calls, the product of the mapped axis sizes must be less than or equal to the number of XLA devices.
Multi-host platforms: On multi-host platforms such as TPU pods, pmap is designed to be used in SPMD Python programs, where every host is running the same Python code such that all hosts run the same pmapped function in the same order. Each host should still call the pmapped function with mapped axis size equal to the number of local devices (unless devices is specified, see below), and an array of the same leading axis size will be returned as usual. However, any collective operations in fun will be computed over all participating devices, including those on other hosts, via device-to-device communication. Conceptually, this can be thought of as running a pmap over a single array sharded across hosts, where each host “sees” only its local shard of the input and output. The SPMD model requires that the same multi-host pmaps must be run in the same order on all devices, but they can be interspersed with arbitrary operations running on a single host.
Parameters:
- fun (Callable) – Function to be mapped over argument axes. Its arguments and return value should be arrays, scalars, or (nested) standard Python containers (tuple/list/dict) thereof.
- axis_name (Optional[Any]) – Optional, a hashable Python object used to identify the mapped axis so that parallel collectives can be applied.
- static_broadcasted_argnums (Union[int, Iterable[int]]) – An int or collection of ints specifying which positional arguments to treat as static (compile-time constant). Operations that only depend on static arguments will be constant-folded. Calling the pmapped function with different values for these constants will trigger recompilation. If the pmapped function is called with fewer positional arguments than indicated by static_broadcasted_argnums then an error is raised. Each of the static arguments will be broadcasted to all devices. Defaults to ().
- devices – This is an experimental feature and the API is likely to change. Optional, a sequence of Devices to map over. (Available devices can be retrieved via jax.devices().) If specified, the size of the mapped axis must be equal to the number of local devices in the sequence. Nested pmaps with devices specified in either the inner or outer pmap are not yet supported.
- backend (Optional[str]) – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: 'cpu', 'gpu', or 'tpu'.
Returns: A parallelized version of fun with arguments that correspond to those of fun, but each with an additional leading array axis (with equal sizes), and with output that has an additional leading array axis (with the same size).

For example, assuming 8 XLA devices are available, pmap can be used as a map along a leading array axis:

>>> import jax
>>> import jax.numpy as np
>>> from jax import pmap
>>> out = pmap(lambda x: x ** 2)(np.arange(8))
>>> print(out)
[ 0  1  4  9 16 25 36 49]
>>> x = np.arange(3 * 2 * 2.).reshape((3, 2, 2))
>>> y = np.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
>>> out = pmap(np.dot)(x, y)
>>> print(out)
[[[    4.     9.]
  [   12.    29.]]
 [[  244.   345.]
  [  348.   493.]]
 [[ 1412.  1737.]
  [ 1740.  2141.]]]
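Without collectives, mapping over the leading axis means that each device applies fun to one slice and the per-device results are stacked. The following is a NumPy model of that behaviour, not JAX code. The helper name reference_pmap is hypothetical. It reproduces the two outputs above on a single CPU, which can help when checking expected values with fewer than 8 devices.

```python
import numpy as np

def reference_pmap(f):
    """Hypothetical single-device model of pmap without collectives:
    apply f to each slice along the leading axis and stack the results."""
    def mapped(*args):
        n = args[0].shape[0]
        # Every mapped argument must share the same leading axis size.
        assert all(a.shape[0] == n for a in args)
        return np.stack([f(*(a[i] for a in args)) for i in range(n)])
    return mapped

squares = reference_pmap(lambda x: x ** 2)(np.arange(8))

x = np.arange(3 * 2 * 2.).reshape((3, 2, 2))
y = np.arange(3 * 2 * 2.).reshape((3, 2, 2)) ** 2
batched_dot = reference_pmap(np.dot)(x, y)

# Mapping np.dot over the leading axis is a batched matrix product.
assert np.allclose(batched_dot, np.matmul(x, y))
print(squares)
print(batched_dot)
```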
In addition to expressing pure maps, pmap can also be used to express parallel single-program multiple-data (SPMD) programs that communicate via collective operations. For example:

>>> f = lambda x: x / jax.lax.psum(x, axis_name='i')
>>> out = pmap(f, axis_name='i')(np.arange(4.))
>>> print(out)
[ 0.          0.16666667  0.33333334  0.5       ]
>>> print(out.sum())
1.0
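With psum over a named axis, every device receives the same value: the sum of the inputs at all positions along that axis. The following is a minimal NumPy model of the one-axis example above, assuming one scalar per device. The name psum_model is hypothetical and is not a JAX function.

```python
import numpy as np

def psum_model(per_device_values):
    """Hypothetical model of jax.lax.psum over a single mapped axis:
    every device receives the total across all devices."""
    total = per_device_values.sum(axis=0)
    return np.broadcast_to(total, per_device_values.shape)

x = np.arange(4.)          # one value per device along axis 'i'
out = x / psum_model(x)    # what f computes on each device
print(out)
print(out.sum())
```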
In this example, axis_name is a string, but it can be any Python object with __hash__ and __eq__ defined.

The argument axis_name to pmap names the mapped axis so that collective operations, like jax.lax.psum, can refer to it. Axis names are particularly important for nested pmap functions, where collectives can operate over distinct axes:

>>> from functools import partial
>>> @partial(pmap, axis_name='rows')
... @partial(pmap, axis_name='cols')
... def normalize(x):
...     row_normed = x / jax.lax.psum(x, 'rows')
...     col_normed = x / jax.lax.psum(x, 'cols')
...     doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
...     return row_normed, col_normed, doubly_normed
>>> x = np.arange(8.).reshape((4, 2))
>>> row_normed, col_normed, doubly_normed = normalize(x)
>>> print(row_normed.sum(0))
[ 1.  1.]
>>> print(col_normed.sum(1))
[ 1.  1.  1.  1.]
>>> print(doubly_normed.sum((0, 1)))
1.0
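In the nested example, each axis name picks out one array axis: the outer pmap ('rows') maps axis 0 and the inner pmap ('cols') maps axis 1. A psum over a name therefore sums over that axis and broadcasts the result back. The following is a NumPy model of that correspondence. The AXES table and named_psum are hypothetical illustrations, not JAX APIs.

```python
import numpy as np

# Hypothetical mapping from axis names to array axes for the nested pmap:
# the outer pmap ('rows') maps axis 0, the inner pmap ('cols') maps axis 1.
AXES = {'rows': 0, 'cols': 1}

def named_psum(x, names):
    """Model of jax.lax.psum over one or more named axes: sum over the
    corresponding array axes and broadcast back to every device."""
    if isinstance(names, str):
        names = (names,)
    axes = tuple(AXES[n] for n in names)
    return np.broadcast_to(x.sum(axis=axes, keepdims=True), x.shape)

x = np.arange(8.).reshape((4, 2))
row_normed = x / named_psum(x, 'rows')
col_normed = x / named_psum(x, 'cols')
doubly_normed = x / named_psum(x, ('rows', 'cols'))

print(row_normed.sum(0))
print(col_normed.sum(1))
print(doubly_normed.sum((0, 1)))
```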
On multi-host platforms, collective operations operate over all devices, including those on other hosts. For example, assuming the following code runs on two hosts with 4 XLA devices each:

>>> f = lambda x: x + jax.lax.psum(x, axis_name='i')
>>> data = np.arange(4) if jax.host_id() == 0 else np.arange(4, 8)
>>> out = pmap(f, axis_name='i')(data)
>>> print(out)
[28 29 30 31] # on host 0
[32 33 34 35] # on host 1

Each host passes in a different length-4 array, corresponding to its 4 local devices, and the psum operates over all 8 values. Conceptually, the two length-4 arrays can be thought of as a sharded length-8 array (in this example equivalent to np.arange(8)) that is mapped over, with the length-8 mapped axis given the name ‘i’. The pmap call on each host then returns the corresponding length-4 output shard.
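The "sharded global array" view can be made concrete with a small NumPy model. It assumes the same two hypothetical hosts with 4 local devices each. The code concatenates the per-host inputs into the global array and applies f with a global psum. It then hands each host back only its own slice. This is an illustration of the semantics, not how JAX implements cross-host communication.

```python
import numpy as np

# Hypothetical two-host setup with 4 local devices per host, as above.
LOCAL_DEVICES = 4
host_inputs = {0: np.arange(4), 1: np.arange(4, 8)}

# Conceptually, the hosts' inputs form one global length-8 array ...
global_x = np.concatenate([host_inputs[h] for h in sorted(host_inputs)])
# ... and psum over axis 'i' sums across all 8 devices, visible on every host.
total = global_x.sum()
global_out = global_x + total

# Each host gets back only the output shard for its own local devices.
host_outputs = {
    h: global_out[LOCAL_DEVICES * h: LOCAL_DEVICES * (h + 1)]
    for h in host_inputs
}
print(host_outputs[0])
print(host_outputs[1])
```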
The devices argument can be used to specify exactly which devices are used to run the parallel computation. For example, again assuming a single host with 8 devices, the following code defines two parallel computations, one that runs on the first six devices and one on the remaining two:

>>> from functools import partial
>>> @partial(pmap, axis_name='i', devices=jax.devices()[:6])
... def f1(x):
...     return x / jax.lax.psum(x, axis_name='i')
>>> @partial(pmap, axis_name='i', devices=jax.devices()[-2:])
... def f2(x):
...     return jax.lax.psum(x ** 2, axis_name='i')
>>> print(f1(np.arange(6.)))
[0. 0.06666667 0.13333333 0.2 0.26666667 0.33333333]
>>> print(f2(np.array([2., 3.])))
[ 13. 13.]
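The rule that the mapped axis size must equal the number of devices in the devices sequence can be checked before calling pmap. In the following sketch, plain integers stand in for the Device objects returned by jax.devices(). The helper check_mapped_size is hypothetical. The last two lines compute NumPy reference values for f1 and f2.

```python
import numpy as np

# Stand-ins for jax.devices() on a single host with 8 devices
# (plain integers here; the real entries are Device objects).
devices = list(range(8))
first_six, last_two = devices[:6], devices[-2:]

# The two computations run on disjoint device sets.
assert not set(first_six) & set(last_two)

def check_mapped_size(x, device_subset):
    """Hypothetical check: with devices given, the mapped (leading) axis
    size must equal the number of devices in the sequence."""
    if x.shape[0] != len(device_subset):
        raise ValueError(
            f"mapped axis size {x.shape[0]} != {len(device_subset)} devices")

x1, x2 = np.arange(6.), np.array([2., 3.])
check_mapped_size(x1, first_six)
check_mapped_size(x2, last_two)

# Reference values: f1 divides by the psum, f2 returns the psum of squares.
out1 = x1 / x1.sum()
out2 = np.full_like(x2, (x2 ** 2).sum())
print(out1)
print(out2)
```

Selecting devices[2:] instead would yield six devices. In that case, check_mapped_size would reject the length-2 input to f2.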

jax.devices(backend=None)
Returns a list of all devices.

Each device is represented by a subclass of Device (e.g. CpuDevice, GpuDevice). The length of the returned list is equal to device_count(). Local devices can be identified by comparing Device.host_id to host_id().

Parameters: backend – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: ‘cpu’, ‘gpu’, or ‘tpu’.
Returns: List of Device objects (instances of Device subclasses).
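Identifying local devices amounts to filtering the global device list by host ID. With real JAX this corresponds roughly to [d for d in jax.devices() if d.host_id == jax.host_id()]. The following runnable model instead uses a hypothetical FakeDevice dataclass, so it runs without accelerators or multiple hosts.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class FakeDevice:
    """Hypothetical stand-in for a JAX Device: only the fields used here."""
    id: int
    host_id: int

# Two hosts with 4 devices each, like the multi-host pmap example.
all_devices = [FakeDevice(id=i, host_id=i // 4) for i in range(8)]

def local_devices_model(devices, this_host_id):
    """Devices whose host_id matches this host's ID are local to it."""
    return [d for d in devices if d.host_id == this_host_id]

host1_devices = local_devices_model(all_devices, this_host_id=1)
print([d.id for d in host1_devices])
```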

jax.local_devices(host_id=None, backend=None)
Returns a list of devices local to a given host (this host by default).

jax.host_id(backend=None)
Returns the integer host ID of this host.

On most platforms this is always 0; it varies only on multi-host platforms.

Parameters: backend – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: ‘cpu’, ‘gpu’, or ‘tpu’.
Returns: Integer host ID.

jax.device_count(backend=None)
Returns the total number of devices.

On most platforms, this is the same as local_device_count(). However, on multi-host platforms, it returns the total number of devices across all hosts.

Parameters: backend – This is an experimental feature and the API is likely to change. Optional, a string representing the XLA backend: ‘cpu’, ‘gpu’, or ‘tpu’.
Returns: Number of devices.
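The relationship between device_count() and local_device_count() can be sketched with a plain count. The model below assumes a hypothetical list host_of that records the host ID of each device. The total count covers every host, while each local count covers one host. The two numbers coincide only when there is a single host.

```python
from collections import Counter

# Hypothetical host assignment: host_of[i] is the host ID of device i
# (two hosts with 4 devices each).
host_of = [0, 0, 0, 0, 1, 1, 1, 1]

device_count = len(host_of)             # devices across all hosts
local_device_count = Counter(host_of)   # devices per host ID
print(device_count, dict(local_device_count))

# On a single-host platform, the total and the local count coincide.
single_host = [0] * 8
assert len(single_host) == Counter(single_host)[0]
```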