jax.nn.squareplus
- jax.nn.squareplus(x, b=4)
Squareplus activation function.
Computes the element-wise function
\[\mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}\]
as described in https://arxiv.org/abs/2112.11687.
- Parameters:
x (ArrayLike) – input array
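b (ArrayLike) – smoothness parameter (default 4). Setting b = 0 reduces the function to ReLU; larger values give a smoother curve around x = 0.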
- Return type:
Array
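As a rough illustration of the formula above, the following sketch compares `jax.nn.squareplus` with a direct transcription of the definition. The helper `squareplus_reference` and the sample inputs are hypothetical, introduced only for this example; they are not part of the library.

```python
import jax
import jax.numpy as jnp


def squareplus_reference(x, b=4.0):
    # Hypothetical helper: a direct transcription of
    # (x + sqrt(x^2 + b)) / 2, used only for comparison.
    return (x + jnp.sqrt(x * x + b)) / 2


# Illustrative sample inputs (floating point).
x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])

# The library function with the default b=4 should agree with the formula.
y_lib = jax.nn.squareplus(x)
matches_formula = bool(jnp.allclose(y_lib, squareplus_reference(x)))

# With b = 0 the formula becomes (x + |x|) / 2, which is ReLU.
y_b0 = jax.nn.squareplus(x, b=0.0)
matches_relu = bool(jnp.allclose(y_b0, jax.nn.relu(x)))

# At x = 0 with b = 4 the value is sqrt(4) / 2 = 1.
value_at_zero = float(jax.nn.squareplus(jnp.array(0.0)))

# The derivative is (1 + x / sqrt(x^2 + b)) / 2, so it equals 0.5 at x = 0.
grad_at_zero = float(jax.grad(jax.nn.squareplus)(0.0))
```

Unlike ReLU, the function is smooth at the origin for b > 0, which is why the gradient at x = 0 is well defined in the last line.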