jax.nn.gelu

jax.nn.gelu(x, approximate=True)

Gaussian error linear unit activation function.

If approximate=False, computes the element-wise function:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( \frac{x}{\sqrt{2}} \right) \right)\]

If approximate=True, uses the approximate formulation of GELU:

\[\mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left( \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)\]

For more information, see Gaussian Error Linear Units (GELUs), section 2.
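
As an illustrative sketch (not part of the library's own examples), the two formulas above can be written out by hand and compared against `jax.nn.gelu`. The names `exact_manual`, `approx_manual`, and `max_gap`, the input grid, and the tolerances are assumptions chosen for this example.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import erf

# Sample inputs on a small grid; the range is arbitrary and chosen for illustration.
x = jnp.linspace(-3.0, 3.0, 13)

# Exact formulation: x/2 * (1 + erf(x / sqrt(2)))
exact_manual = 0.5 * x * (1.0 + erf(x / jnp.sqrt(2.0)))

# Tanh approximation: x/2 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
approx_manual = 0.5 * x * (
    1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3))
)

# Library results for both settings of `approximate`.
exact = jax.nn.gelu(x, approximate=False)
approx = jax.nn.gelu(x, approximate=True)

# Hand-written formulas should agree with the library up to float32 rounding.
print(jnp.allclose(exact, exact_manual, atol=1e-5))
print(jnp.allclose(approx, approx_manual, atol=1e-5))

# Largest pointwise difference between the two formulations on this grid.
max_gap = jnp.max(jnp.abs(exact - approx))
print(max_gap)
```

Because `approximate=True` is the default, a call such as `jax.nn.gelu(x)` uses the tanh formulation; pass `approximate=False` when the erf-based definition is required, for example when matching reference values computed with the exact formula.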

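Parameters:
x (ArrayLike) – input array
approximate (bool) – whether to use the approximate (tanh) or exact (erf) formulation

Return type: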

Array