jax.numpy.piecewise(x, condlist, funclist, *args, **kw)

Evaluate a piecewise-defined function.

LAX-backend implementation of piecewise(). Unlike np.piecewise, jax.numpy.piecewise() requires the functions in funclist to be traceable by JAX, because it is implemented via jax.lax.switch(). See the jax.lax.switch() documentation for more information.
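Because the branches run inside jax.lax.switch(), they must be written with traceable JAX operations. The sketch below assumes a standard `jax` install and uses an illustrative function name, `abs_via_piecewise`, which is not part of the JAX API. It wraps jax.numpy.piecewise() in jax.jit, and each branch applies only JAX arithmetic to its argument.

```python
import jax
import jax.numpy as jnp


@jax.jit
def abs_via_piecewise(x):
    # Illustrative helper (hypothetical name). Each branch is traceable:
    # it only applies JAX arithmetic to its argument, with no Python
    # `if` on array values.
    return jnp.piecewise(x, [x < 0, x >= 0], [lambda v: -v, lambda v: v])


x = jnp.linspace(-2.5, 2.5, 6)
print(abs_via_piecewise(x))
```

A branch such as `lambda v: -v if v < 0 else v` would be expected to fail here, because a Python `if` cannot branch on a traced value.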

Original docstring below.

Given a set of conditions and corresponding functions, evaluate each function on the input data wherever its condition is true.

Parameters

  • x (ndarray or scalar) – The input domain.

  • condlist (list of bool arrays or bool scalars) – Each boolean array corresponds to a function in funclist. Wherever condlist[i] is True, funclist[i](x) is used as the output value.

  • funclist (list of callables, f(x,*args,**kw), or scalars) – Each function is evaluated over x wherever its corresponding condition is True. It should take a 1d array as input and return a 1d array or a scalar value as output. If a scalar is provided instead of a callable, a constant function (lambda x: scalar) is assumed.

  • args (tuple, optional) – Any further arguments given to piecewise are passed to the functions upon execution, i.e., if called piecewise(..., ..., 1, 'a'), then each function is called as f(x, 1, 'a').

  • kw (dict, optional) – Keyword arguments used in calling piecewise are passed to the functions upon execution, i.e., if called piecewise(..., ..., alpha=1), then each function is called as f(x, alpha=1).
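A minimal sketch of how extra positional and keyword arguments are forwarded with jax.numpy.piecewise(). The names `scaled_shift`, `identity`, `scale`, and `offset` are hypothetical, chosen for illustration. Every callable in funclist receives the same extra arguments, so each one must accept them, even if it ignores them.

```python
import jax.numpy as jnp


def scaled_shift(v, scale, offset=0.0):
    # Receives `scale` positionally (from *args) and `offset` by keyword (from **kw).
    return scale * v + offset


def identity(v, scale, offset=0.0):
    # Accepts the same extra arguments even though it ignores them.
    return v


x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
out = jnp.piecewise(
    x,
    [x < 0, x >= 0],
    [scaled_shift, identity],
    3.0,         # forwarded positionally: f(v, 3.0, ...)
    offset=1.0,  # forwarded by keyword: f(..., offset=1.0)
)
print(out)
```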


Returns

out – The output is the same shape and type as x and is found by calling the functions in funclist on the appropriate portions of x, as defined by the boolean arrays in condlist. Portions not covered by any condition have a default value of 0.
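As a quick check of the default value and the dtype rule, the following sketch leaves some elements of an integer input uncovered by any condition. Those elements come out as 0, and the result keeps the shape and dtype of x.

```python
import jax.numpy as jnp

x = jnp.arange(-3, 4)  # integer array: [-3, -2, -1, 0, 1, 2, 3]
out = jnp.piecewise(
    x,
    [x < -1, x > 1],
    [lambda v: v * 10, lambda v: v * 100],
)
# -1, 0 and 1 satisfy neither condition, so they take the default value 0.
print(out, out.dtype)
```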

Return type



Notes

This is similar to choose or select, except that functions are evaluated on elements of x that satisfy the corresponding condition from condlist.
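To make the comparison with select concrete, the sketch below computes the same absolute value twice. jnp.select chooses among arrays that have already been computed over all of x, while jnp.piecewise takes one function per condition. For these inputs the two results agree; the example compares results only and says nothing about how much work either call performs internally.

```python
import jax.numpy as jnp

x = jnp.linspace(-2.0, 2.0, 5)
conds = [x < 0, x >= 0]

# select: each choice is an array already evaluated on every element of x.
via_select = jnp.select(conds, [-x, x])

# piecewise: each choice is a function applied where its condition holds.
via_piecewise = jnp.piecewise(x, conds, [lambda v: -v, lambda v: v])

print(via_select, via_piecewise)
```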

The result is:

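$$
\mathrm{out} =
\begin{cases}
\mathrm{funclist}[0](x[\mathrm{condlist}[0]]) \\
\mathrm{funclist}[1](x[\mathrm{condlist}[1]]) \\
\quad\vdots \\
\mathrm{funclist}[n-1](x[\mathrm{condlist}[n-1]])
\end{cases}
\qquad n = \mathrm{len}(\mathrm{condlist})
$$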
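The case-by-case result can be mimicked with jax.lax.switch. This is a simplified sketch, not JAX's actual implementation: `piecewise_sketch` is a hypothetical helper that assumes len(funclist) == len(condlist), that every entry of funclist is a callable mapping one element to one element, and that when several conditions hold the last one wins (the same outcome as NumPy, where later conditions overwrite earlier ones). Branch 0 supplies the default value 0.

```python
import jax
import jax.numpy as jnp


def piecewise_sketch(x, condlist, funclist):
    # Hypothetical helper for illustration only; not the JAX implementation.
    x = jnp.asarray(x)
    conds = jnp.asarray(condlist, dtype=bool)  # shape (n, *x.shape)
    n = conds.shape[0]
    # Per-element branch index: i + 1 for the last True condition i, or 0 if none holds.
    labels = jnp.arange(1, n + 1).reshape((n,) + (1,) * x.ndim)
    index = jnp.max(jnp.where(conds, labels, 0), axis=0)
    # Branch 0 returns the default value 0; branch i + 1 applies funclist[i].
    branches = [lambda v: jnp.zeros_like(v)]
    branches += [lambda v, f=f: jnp.asarray(f(v)).astype(v.dtype) for f in funclist]
    # Apply lax.switch element by element via vmap over the flattened arrays.
    flat = jax.vmap(lambda i, v: jax.lax.switch(i, branches, v))(index.ravel(), x.ravel())
    return flat.reshape(x.shape)


x = jnp.linspace(-2.5, 2.5, 6)
print(piecewise_sketch(x, [x < 0, x >= 0], [lambda v: -v, lambda v: v]))
```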


Examples

Define the sign (signum) function, which is -1 for x < 0 and +1 for x >= 0.

>>> x = np.linspace(-2.5, 2.5, 6)
>>> np.piecewise(x, [x < 0, x >= 0], [-1, 1])
array([-1., -1., -1.,  1.,  1.,  1.])

Define the absolute value, which is -x for x < 0 and x for x >= 0.

>>> np.piecewise(x, [x < 0, x >= 0], [lambda x: -x, lambda x: x])
array([2.5, 1.5, 0.5, 0.5, 1.5, 2.5])

Apply the same function to a scalar value.

>>> y = -2
>>> np.piecewise(y, [y < 0, y >= 0], [lambda x: -x, lambda x: x])
array(2)
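The same three NumPy examples can be written against jax.numpy. This is a sketch assuming `jax` is installed; the variable names `sign_like`, `abs_x`, and `abs_y` are illustrative. The constants are given as plain numbers, as in the NumPy examples, and the branch functions use only traceable arithmetic.

```python
import jax.numpy as jnp

x = jnp.linspace(-2.5, 2.5, 6)

# Sign-like function: constant -1 where x < 0 and +1 where x >= 0.
sign_like = jnp.piecewise(x, [x < 0, x >= 0], [-1, 1])

# Absolute value: one traceable function per condition.
abs_x = jnp.piecewise(x, [x < 0, x >= 0], [lambda v: -v, lambda v: v])

# Scalar input: the result is a 0-d JAX array.
y = -2
abs_y = jnp.piecewise(y, [y < 0, y >= 0], [lambda v: -v, lambda v: v])

print(sign_like, abs_x, abs_y)
```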