jax.nn.sparse_plus

jax.nn.sparse_plus(x)

Sparse plus function.

Computes the function:

\[\begin{split}\mathrm{sparse\_plus}(x) = \begin{cases} 0, & x \leq -1\\ \frac{1}{4}(x+1)^2, & -1 < x < 1 \\ x, & 1 \leq x \end{cases}\end{split}\]
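The cases above can be written directly with nested `jnp.where` calls. The function below, `sparse_plus_reference`, is a hypothetical name used only for this illustration. It is a minimal sketch of the formula and is not claimed to be how `jax.nn.sparse_plus` is implemented internally.

```python
import jax.numpy as jnp

def sparse_plus_reference(x):
    """Illustrative piecewise evaluation of sparse_plus (hypothetical helper)."""
    x = jnp.asarray(x)
    quadratic = 0.25 * (x + 1.0) ** 2  # middle piece, used for -1 < x < 1
    # x <= -1 -> 0; -1 < x < 1 -> quadratic; x >= 1 -> identity
    return jnp.where(x <= -1.0, 0.0, jnp.where(x < 1.0, quadratic, x))

values = sparse_plus_reference(jnp.array([-2.0, -1.0, 0.0, 0.5, 1.0, 3.0]))
# values: [0.0, 0.0, 0.25, 0.5625, 1.0, 3.0]
```

At both breakpoints the pieces agree: the quadratic gives 0 at x = -1 and 1 at x = 1, so the function is continuous.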

This is a piecewise-polynomial counterpart of the softplus activation. It outputs exactly zero for inputs at or below -1 and equals the identity for inputs at or above 1. The quadratic segment on (-1, 1) joins the two pieces so that the function is continuously differentiable, convex, and monotonically non-decreasing.
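The derivative of the quadratic piece is (x + 1)/2. This equals 0 at x = -1 and 1 at x = 1, matching the slopes of the neighbouring pieces. The sketch below checks this with `jax.grad` on a hypothetical reimplementation (`sparse_plus_reference`, not part of the JAX API). It also shows that the second derivative is 0 outside (-1, 1) and 1/2 inside. The function is therefore convex, but its curvature jumps at the breakpoints, so it is continuously differentiable rather than infinitely smooth.

```python
import jax
import jax.numpy as jnp

def sparse_plus_reference(x):
    """Hypothetical piecewise reimplementation, for illustration only."""
    quadratic = 0.25 * (x + 1.0) ** 2
    return jnp.where(x <= -1.0, 0.0, jnp.where(x < 1.0, quadratic, x))

# First and second derivatives via autodiff (scalar in, scalar out).
d1 = jax.grad(sparse_plus_reference)
d2 = jax.grad(d1)

eps = 1e-3
# Slopes just left and right of each breakpoint nearly coincide.
left_break = (float(d1(-1.0 - eps)), float(d1(-1.0 + eps)))   # ~ (0.0, 0.0005)
right_break = (float(d1(1.0 - eps)), float(d1(1.0 + eps)))    # ~ (0.9995, 1.0)

# Curvature is non-negative everywhere (convex) but not continuous.
curvature = [float(d2(v)) for v in (-2.0, 0.0, 2.0)]          # [0.0, 0.5, 0.0]
```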

Parameters:

x (jax.typing.ArrayLike) – input (float)

Return type:

Array
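A short usage sketch, assuming a JAX version that ships `jax.nn.sparse_plus`. The function is applied elementwise, so the result has the same shape as the input. It can also be differentiated and vectorized like the other `jax.nn` activations.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-3.0, -1.0, 0.0, 1.0, 2.5])

# Elementwise activation: output shape matches input shape.
y = jax.nn.sparse_plus(x)                            # [0.0, 0.0, 0.25, 1.0, 2.5]

# Per-element derivative via grad + vmap; equals (x + 1)/2 on (-1, 1).
slope = jax.vmap(jax.grad(jax.nn.sparse_plus))(x)    # [0.0, 0.0, 0.5, 1.0, 1.0]
```

At x = -1 and x = 1, the slope is the same whichever piece is selected (0 and 1 respectively), so the gradients above do not depend on how the boundaries are assigned.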