jax.nn.squareplus

jax.nn.squareplus(x, b=4)

Squareplus activation function.

Computes the element-wise function

\[\mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}\]

as described in https://arxiv.org/abs/2112.11687.
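As a sanity check, the formula can be written out directly and compared with the library function. This is a minimal sketch that assumes JAX is installed. `squareplus_reference` is a hypothetical helper written only for this comparison and is not part of `jax.nn`. The gradient check uses the derivative of the formula, (1 + x / sqrt(x^2 + b)) / 2, which equals 0.5 at x = 0.

```python
import jax
import jax.numpy as jnp


def squareplus_reference(x, b=4.0):
    # Hypothetical helper: a direct transcription of (x + sqrt(x^2 + b)) / 2.
    x = jnp.asarray(x)
    return (x + jnp.sqrt(x * x + b)) / 2


x = jnp.array([-3.0, -1.0, 0.0, 1.0, 3.0])
y = jax.nn.squareplus(x)          # uses the default b=4
y_ref = squareplus_reference(x)   # same formula, written by hand

# At x = 0 the output is sqrt(b) / 2, which is 1.0 for b=4.
# The derivative (1 + x / sqrt(x^2 + b)) / 2 is 0.5 at x = 0.
grad_at_zero = jax.grad(jax.nn.squareplus)(0.0)

print(y)
print(y_ref)
print(grad_at_zero)
```

Because every operation in the formula is differentiable whenever b > 0, `jax.grad` can be applied to the function directly, with no custom gradient rule.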

Parameters:
  • x (jax.typing.ArrayLike) – input array

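  • b (jax.typing.ArrayLike) – smoothness parameter (default 4); should be non-negative, since a negative b makes the square root undefined near x = 0. Larger values give a wider, smoother transition around x = 0, and b = 0 reduces the function to ReLU, (x + |x|)/2.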
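The effect of `b` follows directly from the formula. This is a small sketch that assumes JAX is installed. With b = 0 the output matches `jax.nn.relu`. The value at the origin is sqrt(b)/2, so a larger b lifts the curve near zero. To keep the example simple, each `b` is passed as a Python scalar in a loop, even though the parameter is typed as `ArrayLike`.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])

# With b = 0 the formula becomes (x + |x|) / 2, which is exactly ReLU.
sp0 = jax.nn.squareplus(x, b=0)
relu = jax.nn.relu(x)

# The value at x = 0 is sqrt(b) / 2, so larger b lifts the curve around the
# origin and widens the smooth transition between the two linear regimes.
at_zero = {b: float(jax.nn.squareplus(0.0, b=b)) for b in (0.0, 1.0, 4.0, 16.0)}

print(sp0)
print(relu)
print(at_zero)  # {0.0: 0.0, 1.0: 0.5, 4.0: 1.0, 16.0: 2.0}
```

Far from the origin the value of b matters little, because the output approaches x for large positive x and 0 for large negative x.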

Return type:

Array