jax.numpy.piecewise
- jax.numpy.piecewise(x, condlist, funclist, *args, **kw)
Evaluate a function defined piecewise across the domain.
JAX implementation of numpy.piecewise(), in terms of jax.lax.switch().
Note
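Unlike numpy.piecewise(), jax.numpy.piecewise() requires functions in funclist to be traceable by JAX, as it is implemented via jax.lax.switch(). In practice, each function should be built from JAX operations and must not make Python control-flow decisions (such as an if statement) based on the values of its input.
- Parameters: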
x (ArrayLike) – array of input values.
condlist (Array | Sequence[ArrayLike]) – boolean array or sequence of boolean arrays corresponding to the functions in funclist. If a sequence of arrays, the length of each array must match the length of x.
funclist (list[ArrayLike | Callable[..., Array]]) – list of arrays or functions; must either be the same length as condlist, or have length len(condlist) + 1, in which case the last entry is the default applied when none of the conditions are True. Alternatively, entries of funclist may be numerical values, in which case they indicate a constant function.
args – additional positional arguments passed to each function in funclist.
kw – additional keyword arguments passed to each function in funclist.
- Returns:
An array which is the result of evaluating the functions on x at the specified conditions.
- Return type:
Array
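The extra positional and keyword arguments (*args and **kw in the signature) are forwarded to every callable in funclist after the input values. A minimal sketch, assuming float inputs; the branch functions left_branch and right_branch and their scale and offset parameters are hypothetical names chosen for illustration:

```python
import jax.numpy as jnp

# Hypothetical branch functions: each receives the input values first,
# followed by the extra arguments given to jnp.piecewise.
def left_branch(v, scale, offset=0.0):
    return -scale * v + offset

def right_branch(v, scale, offset=0.0):
    return scale * v + offset

x = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])

# 3.0 is forwarded positionally as `scale`; offset=1.0 is forwarded as a keyword.
y = jnp.piecewise(x, [x < 0, x >= 0], [left_branch, right_branch], 3.0, offset=1.0)
print(y)  # values: [7., 4., 1., 4., 7.]
```

Because the same extra arguments reach every function, each callable in funclist should accept the same parameters, even if some branches ignore them.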
See also
- jax.lax.switch(): choose between N functions based on an index.
- jax.lax.cond(): choose between two functions based on a boolean condition.
- jax.numpy.where(): choose between two results based on a boolean mask.
- jax.lax.select(): choose between two results based on a boolean mask.
- jax.lax.select_n(): choose between N results based on a boolean or integer index array.
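To illustrate how a per-element branch index can drive jax.lax.switch(), here is a simplified model. It is an illustrative sketch, not JAX's actual source: it assumes condlist is a list of boolean arrays with the same shape as x and that funclist holds len(condlist) + 1 callables, the last being the default. The helper name piecewise_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

def piecewise_sketch(x, condlist, funclist):
    """Hypothetical, simplified model of jnp.piecewise (illustration only).

    Assumes every condition has x's shape and funclist holds
    len(condlist) + 1 callables, the last one being the default.
    """
    x = jnp.asarray(x)
    # Branch 0 is the default; branch k + 1 belongs to condlist[k].
    branches = [funclist[-1]] + list(funclist[:-1])
    # Start every element on the default branch, then let each True
    # condition overwrite the index, so later conditions take precedence.
    index = jnp.zeros(x.shape, dtype=jnp.int32)
    for k, cond in enumerate(condlist):
        index = jnp.where(cond, k + 1, index)
    # Apply lax.switch element by element; vmap maps it over the flattened input.
    pick = jax.vmap(lambda i, v: lax.switch(i, branches, v))
    return pick(index.ravel(), x.ravel()).reshape(x.shape)

x = jnp.arange(-4, 5)
condlist = [x < -1, x > 1]
funclist = [lambda v: 1 + v, lambda v: v - 1, lambda v: 0 * v]
print(piecewise_sketch(x, condlist, funclist))  # values: [-3, -2, -1, 0, 0, 0, 1, 2, 3]
```

One consequence of this design, as we understand vmap's handling of lax.switch with a batched index, is that every branch in the sketch is evaluated on every element and the results are then selected, so each function should tolerate inputs that belong to other pieces.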
Examples
Here’s an example of a function which is zero for negative values, and linear for positive values:
>>> import jax.numpy as jnp
>>> x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
>>> condlist = [x < 0, x >= 0]
>>> funclist = [lambda x: 0 * x, lambda x: x]
>>> jnp.piecewise(x, condlist, funclist)
Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)
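Because of the traceability requirement in the Note above, a function that branches in Python on the values of its input fails when jnp.piecewise traces it, while the same logic written with a JAX operation works and also composes with transformations such as jax.jit() and jax.grad(). A minimal sketch; clip_python, clip_traceable and clipped_ramp are hypothetical names:

```python
import jax
import jax.numpy as jnp

x = jnp.array([-2.0, -1.0, 1.0, 2.0, 4.0])

def clip_python(v):
    # Python `if` needs a concrete bool, but under tracing `v < 3` is abstract.
    return v if v < 3 else 3.0 * jnp.ones_like(v)

def clip_traceable(v):
    # Same logic expressed with a JAX operation, so it traces cleanly.
    return jnp.minimum(v, 3.0)

try:
    jnp.piecewise(x, [x < 0, x >= 0], [lambda v: jnp.zeros_like(v), clip_python])
    print("traced without error")
except jax.errors.ConcretizationTypeError as err:
    print("clip_python is not traceable:", type(err).__name__)

def clipped_ramp(x):
    # Zero for negative values, linear and then clipped at 3 for the rest.
    return jnp.piecewise(x, [x < 0, x >= 0], [lambda v: jnp.zeros_like(v), clip_traceable])

print(jax.jit(clipped_ramp)(x))  # values: [0., 0., 1., 2., 3.]
# The gradient of the sum gives each element's slope.
print(jax.grad(lambda x: jnp.sum(clipped_ramp(x)))(x))  # values: [0., 0., 1., 1., 0.]
```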
funclist can also contain a simple scalar value for constant functions:
>>> condlist = [x < 0, x >= 0]
>>> funclist = [0, lambda x: x]
>>> jnp.piecewise(x, condlist, funclist)
Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32)
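With a single split point and one constant piece, the same result can also be written with jax.numpy.where(), listed under See also. A short comparison, redefining x as in the examples so the snippet runs on its own:

```python
import jax.numpy as jnp

x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])

# Constant 0 where x < 0, identity elsewhere.
via_piecewise = jnp.piecewise(x, [x < 0, x >= 0], [0, lambda v: v])
via_where = jnp.where(x < 0, 0, x)

print(bool(jnp.array_equal(via_piecewise, via_where)))  # True
```

jnp.where takes the values for both outcomes directly, whereas jnp.piecewise takes one function or constant per piece, which tends to read more naturally once there are more than two pieces.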
You can specify a default value, applied wherever none of the conditions holds, by appending an extra entry to funclist:
>>> condlist = [x < -1, x > 1]
>>> funclist = [lambda x: 1 + x, lambda x: x - 1, 0]
>>> jnp.piecewise(x, condlist, funclist)
Array([-3, -2, -1, 0, 0, 0, 1, 2, 3], dtype=int32)
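If the extra entry is omitted, so that funclist has the same length as condlist, elements that match no condition are filled with zero. A sketch comparing an explicit zero default, an implicit one, and a non-zero default (the value 7 is arbitrary):

```python
import jax.numpy as jnp

x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4])
condlist = [x < -1, x > 1]

# Explicit default: the trailing 0 applies where neither condition holds.
explicit = jnp.piecewise(x, condlist, [lambda v: 1 + v, lambda v: v - 1, 0])

# Implicit default: funclist has len(condlist) entries, so the uncovered
# elements (here -1, 0 and 1) come out as zero.
implicit = jnp.piecewise(x, condlist, [lambda v: 1 + v, lambda v: v - 1])

# A non-zero default makes the uncovered elements visible.
with_seven = jnp.piecewise(x, condlist, [lambda v: 1 + v, lambda v: v - 1, 7])

print(explicit.tolist() == implicit.tolist())  # True
print(with_seven)  # values: [-3, -2, -1, 7, 7, 7, 1, 2, 3]
```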
condlist may also be a simple array of scalar conditions, in which case the function paired with the True condition applies to every element of x:
>>> condlist = jnp.array([False, True, False])
>>> funclist = [lambda x: x * 0, lambda x: x * 10, lambda x: x * 100]
>>> jnp.piecewise(x, condlist, funclist)
Array([-40, -30, -20, -10, 0, 10, 20, 30, 40], dtype=int32)
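Scalar conditions can also be computed from traced values, which is useful inside jax.jit(), where a Python if on a traced argument is not allowed. A minimal sketch; scale_by_mode and its mode argument are hypothetical:

```python
import jax
import jax.numpy as jnp

@jax.jit
def scale_by_mode(x, mode):
    # `mode` is traced under jit, so `if mode == 0:` would fail; instead build
    # one scalar condition per function and let jnp.piecewise pick the branch.
    condlist = jnp.array([mode == 0, mode == 1])
    funclist = [lambda v: v * 10, lambda v: v * 100, lambda v: v * 0]  # last entry: default
    return jnp.piecewise(x, condlist, funclist)

x = jnp.arange(3)
print(scale_by_mode(x, 0))  # values: [0, 10, 20]
print(scale_by_mode(x, 1))  # values: [0, 100, 200]
print(scale_by_mode(x, 5))  # neither condition holds, so the default gives [0, 0, 0]
```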