jax.numpy.piecewise

jax.numpy.piecewise(x, condlist, funclist, *args, **kw)

Evaluate a function defined piecewise across the domain.

JAX implementation of numpy.piecewise(), in terms of jax.lax.switch().
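To show how piecewise evaluation maps onto jax.lax.switch(), here is a simplified sketch of the idea: compute, for every element, the index of the branch to apply, then call lax.switch elementwise. The helper piecewise_sketch is hypothetical and is not JAX's actual implementation. It assumes every entry of funclist is callable and that funclist has len(condlist) + 1 entries (the last one being the default). Letting the last True condition win when conditions overlap is a design choice of this sketch.

```python
import jax.numpy as jnp
from jax import lax


def piecewise_sketch(x, condlist, funclist):
    """Hypothetical, simplified piecewise built on lax.switch (illustration only).

    Assumes len(funclist) == len(condlist) + 1 and that all entries are callable.
    """
    x = jnp.asarray(x)
    n = len(condlist)
    # Stack the boolean masks into shape (n, *x.shape).
    conds = jnp.stack([jnp.broadcast_to(c, x.shape) for c in condlist])
    # Branch 0 is the default; branch i + 1 corresponds to condlist[i].
    branches = [funclist[-1]] + list(funclist[:-1])
    # Per element: position of the last True condition, or 0 if none is True.
    positions = jnp.arange(1, n + 1).reshape((n,) + (1,) * x.ndim)
    index = jnp.max(jnp.where(conds, positions, 0), axis=0)
    # Apply lax.switch to each (index, element) pair via jnp.vectorize.
    per_element = lambda i, xi: lax.switch(i, branches, xi)
    return jnp.vectorize(per_element)(index, x)


x = jnp.array([-4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
condlist = [x < -1, x > 1]
funclist = [lambda v: 1 + v, lambda v: v - 1, lambda v: jnp.zeros_like(v)]
print(piecewise_sketch(x, condlist, funclist))
print(jnp.piecewise(x, condlist, funclist))
```

Both calls should agree on this input, since no two conditions overlap.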

Note

Unlike numpy.piecewise(), jax.numpy.piecewise() requires functions in funclist to be traceable by JAX, as it is implemented via jax.lax.switch().
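A minimal sketch of what traceability means in practice: functions built from jnp or arithmetic operations work, including under jax.jit(), while a function that branches with Python control flow on its input fails during tracing, even when piecewise is called outside jax.jit(). The names soft_threshold and not_traceable are hypothetical, chosen for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def soft_threshold(x):
    # Both callables use only traceable operations; 0.0 is the default entry.
    return jnp.piecewise(
        x,
        [x < -1.0, x > 1.0],
        [lambda v: v + 1.0, lambda v: v - 1.0, 0.0],
    )


def not_traceable(v):
    # Python `if` on a traced value needs a concrete bool, so tracing fails.
    return v if v > 0 else -v


x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(soft_threshold(x))  # [-1.  0.  0.  0.  1.]

try:
    jnp.piecewise(x, [x < 0], [not_traceable, 0.0])
    traced_ok = True
except jax.errors.ConcretizationTypeError:
    traced_ok = False
print(traced_ok)  # False
```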

Parameters:
  • x (ArrayLike) – array of input values.

  • condlist (Array | Sequence[ArrayLike]) – boolean array or sequence of boolean arrays corresponding to the functions in funclist. If a sequence of arrays, the length of each array must match the length of x.

  • funclist (list[ArrayLike | Callable[..., Array]]) – list of functions or array-like values. It must either have the same length as condlist, in which case elements where none of the conditions are True are set to 0, or have length len(condlist) + 1, in which case the last entry is the default applied where none of the conditions are True. Entries that are numerical values indicate constant functions.

  • args – additional positional arguments passed to each function in funclist.

  • kw – additional keyword arguments passed to each function in funclist.
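As a sketch of how args and kw are forwarded, each callable in funclist is called with x followed by the extra arguments, so the positional and keyword forms below are equivalent. The parameter name slope is purely illustrative.

```python
import jax.numpy as jnp

x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])
condlist = [x < 0, x >= 0]

# Each callable receives x plus the extra argument `slope` (a hypothetical name).
funclist = [lambda v, slope: -slope * v, lambda v, slope: slope * v]

by_position = jnp.piecewise(x, condlist, funclist, 3.0)    # passed via *args
by_keyword = jnp.piecewise(x, condlist, funclist, slope=3.0)  # passed via **kw
print(by_position)  # [6. 3. 0. 3. 6.]
print(by_keyword)   # [6. 3. 0. 3. 6.]
```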

Returns:

An array which is the result of evaluating the functions on x at the specified conditions.

Return type:

Array
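A short example, assuming a 2-D input, suggesting that the output keeps the shape and dtype of x when each condition has the same shape as x:

```python
import jax.numpy as jnp

x = jnp.arange(6.0).reshape(2, 3)  # [[0., 1., 2.], [3., 4., 5.]]
# One condition with the same shape as x, plus a default entry (identity).
out = jnp.piecewise(x, [x < 3], [lambda v: -v, lambda v: v])
print(out.shape, out.dtype)  # (2, 3) float32
```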

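See also

  • jax.lax.switch() – choose between N functions based on an index.

  • numpy.piecewise() – the NumPy function that this function implements.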

Examples

Here’s an example of a function that is zero for negative values and linear for non-negative values:

>>> import jax.numpy as jnp
>>> x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
>>> condlist = [x < 0, x >= 0]
>>> funclist = [lambda x: 0 * x, lambda x: x]
>>> jnp.piecewise(x, condlist, funclist)
Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)

funclist can also contain scalar values, which act as constant functions:

>>> condlist = [x < 0, x >= 0]
>>> funclist = [0, lambda x: x]
>>> jnp.piecewise(x, condlist, funclist)
Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)

You can specify a default value by appending an extra entry to funclist:

>>> condlist = [x < -1, x > 1]
>>> funclist = [lambda x: 1 + x, lambda x: x - 1, 0]
>>> jnp.piecewise(x, condlist, funclist)
Array([-3, -2, -1,  0,  0,  0,  1,  2,  3], dtype=int32)
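A sketch contrasting a non-zero default with no default entry at all: when funclist has the same length as condlist, the elements that satisfy none of the conditions (here -1, 0 and 1) come out as 0.

```python
import jax.numpy as jnp

x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
condlist = [x < -1, x > 1]
branches = [lambda v: 1 + v, lambda v: v - 1]

# Trailing constant 100 is used wherever no condition is True.
with_default = jnp.piecewise(x, condlist, branches + [100])
# No trailing entry: unmatched elements are filled with 0.
without_default = jnp.piecewise(x, condlist, branches)
print(with_default)
print(without_default)
```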

condlist may also be a simple array of scalar conditions, in which case the function associated with the True condition is applied to every element of x:

>>> condlist = jnp.array([False, True, False])
>>> funclist = [lambda x: x * 0, lambda x: x * 10, lambda x: x * 100]
>>> jnp.piecewise(x, condlist, funclist)
Array([-40, -30, -20, -10,   0,  10,  20,  30,  40], dtype=int32)
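Because scalar conditions select a single function for the whole array, they can also be computed from a traced value, for example a mode flag inside jax.jit(). The helper scale_by_mode below is hypothetical, a sketch of this pattern rather than a recommended API:

```python
import jax
import jax.numpy as jnp


@jax.jit
def scale_by_mode(x, mode):
    # Two scalar conditions built from the traced `mode` argument.
    condlist = jnp.array([mode == 0, mode == 1])
    # The last entry (-v) is the default used when neither condition holds.
    funclist = [lambda v: v * 10, lambda v: v * 100, lambda v: -v]
    return jnp.piecewise(x, condlist, funclist)


x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
print(scale_by_mode(x, 1))  # every element multiplied by 100
print(scale_by_mode(x, 7))  # no condition holds, so the default -x applies
```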