# Source code for jax.nn.functions

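# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.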

"""Shared neural network activations and other functions."""

import numpy as np

from jax import custom_jvp
from jax import dtypes
from jax import lax
from jax import core
from jax.scipy.special import expit
import jax.numpy as jnp

# activations

@custom_jvp
def relu(x):
  r"""Rectified linear unit activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{relu}(x) = \max(x, 0)

  Args:
    x: input array.

  Returns:
    The element-wise maximum of ``x`` and zero.
  """
  return jnp.maximum(x, 0)


# Custom derivative rule: the incoming tangent passes through where x > 0 and
# is replaced by zeros everywhere else (including exactly at x == 0).
relu.defjvps(
    lambda tangent, primal_out, x: lax.select(
        x > 0, tangent, lax.full_like(tangent, 0)))

def softplus(x):
  r"""Softplus activation function.

  Computes the element-wise function

  .. math::
    \mathrm{softplus}(x) = \log(1 + e^x)

  Args:
    x: input array.

  Returns:
    The element-wise softplus of ``x``.
  """
  # log(1 + e^x) == log(e^x + e^0); logaddexp evaluates this without
  # overflowing the intermediate exponential for large positive x.
  return jnp.logaddexp(x, 0)

def soft_sign(x):
  r"""Soft-sign activation function.

  Computes the element-wise function

  .. math::
    \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}

  Args:
    x: input array.

  Returns:
    ``x`` divided element-wise by ``|x| + 1``.
  """
  denominator = jnp.abs(x) + 1
  return x / denominator

def sigmoid(x):
  r"""Sigmoid activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}

  Args:
    x: input array.

  Returns:
    The result of applying ``expit`` to ``x``.
  """
  activated = expit(x)
  return activated

def silu(x):
  r"""SiLU activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}

  Args:
    x: input array.

  Returns:
    ``x`` multiplied element-wise by ``sigmoid(x)``.
  """
  gate = sigmoid(x)
  return x * gate


# ``swish`` is an alias bound to the same function object as ``silu``.
swish = silu

def log_sigmoid(x):
  r"""Log-sigmoid activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})

  Args:
    x: input array.

  Returns:
    The negated softplus of ``-x``.
  """
  negated_input = -x
  return -softplus(negated_input)

def elu(x, alpha=1.0):
  r"""Exponential linear unit activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{elu}(x) = \begin{cases}
      x, & x > 0\\
      \alpha \left(\exp(x) - 1\right), & x \le 0
    \end{cases}

  Args:
    x: input array.
    alpha: scale applied to the ``x <= 0`` branch (default 1.0).

  Returns:
    The element-wise ELU of ``x``.
  """
  is_positive = x > 0
  # Feed zeros to expm1 wherever the exponential branch is not selected, so
  # that branch is only ever evaluated on non-positive inputs.
  exp_input = jnp.where(is_positive, 0., x)
  negative_branch = alpha * jnp.expm1(exp_input)
  return jnp.where(is_positive, x, negative_branch)

def leaky_relu(x, negative_slope=1e-2):
  r"""Leaky rectified linear unit activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{leaky\_relu}(x) = \begin{cases}
      x, & x \ge 0\\
      \alpha x, & x < 0
    \end{cases}

  where :math:`\alpha` = :code:`negative_slope`.

  Args:
    x: input array.
    negative_slope: multiplier applied to negative inputs (default 1e-2).

  Returns:
    ``x`` where ``x >= 0`` and ``negative_slope * x`` elsewhere.
  """
  scaled = negative_slope * x
  return jnp.where(x >= 0, x, scaled)

def hard_tanh(x):
  r"""Hard :math:`\mathrm{tanh}` activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{hard\_tanh}(x) = \begin{cases}
      -1, & x < -1\\
      x, & -1 \le x \le 1\\
      1, & 1 < x
    \end{cases}

  Args:
    x: input array.

  Returns:
    ``x`` with values above 1 replaced by 1 and values below -1 replaced
    by -1; values in ``[-1, 1]`` are returned unchanged.
  """
  # Inner select handles the lower bound, outer select the upper bound.
  lower_clamped = jnp.where(x < -1, -1, x)
  return jnp.where(x > 1, 1, lower_clamped)

def celu(x, alpha=1.0):
  r"""Continuously-differentiable exponential linear unit activation.

  Computes the element-wise function:

  .. math::
    \mathrm{celu}(x) = \begin{cases}
      x, & x > 0\\
      \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
    \end{cases}

  For more information, see
  `Continuously Differentiable Exponential Linear Units
  <https://arxiv.org/pdf/1704.07483.pdf>`_.

  Args:
    x: input array.
    alpha: scale of the ``x <= 0`` branch (default 1.0).

  Returns:
    The element-wise CELU of ``x``.
  """
  is_positive = x > 0
  # Evaluate expm1 only on non-positive inputs (zeros stand in elsewhere).
  # Without this, a large positive x overflows exp(x / alpha) to inf in the
  # unselected branch and the gradient becomes 0 * inf = NaN.
  safe_x = jnp.where(is_positive, 0., x)
  return jnp.where(is_positive, x, alpha * jnp.expm1(safe_x / alpha))

def selu(x):
  r"""Scaled exponential linear unit activation.

  Computes the element-wise function:

  .. math::
    \mathrm{selu}(x) = \lambda \begin{cases}
      x, & x > 0\\
      \alpha e^x - \alpha, & x \le 0
    \end{cases}

  where :math:`\lambda = 1.0507009873554804934193349852946` and
  :math:`\alpha = 1.6732632423543772848170429916717`.

  For more information, see
  `Self-Normalizing Neural Networks
  <https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf>`_.

  Args:
    x: input array.

  Returns:
    ``elu(x, alpha)`` multiplied by the fixed scale :math:`\lambda`.
  """
  selu_lambda = 1.0507009873554804934193349852946
  selu_alpha = 1.6732632423543772848170429916717
  return selu_lambda * elu(x, selu_alpha)

def gelu(x, approximate: bool = True):
  r"""Gaussian error linear unit activation function.

  If ``approximate=False``, computes the element-wise function:

  .. math::
    \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left(
      \frac{x}{\sqrt{2}} \right) \right)

  If ``approximate=True``, uses the approximate formulation of GELU:

  .. math::
    \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
      \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)

  For more information, see `Gaussian Error Linear Units (GELUs)
  <https://arxiv.org/abs/1606.08415>`_, section 2.

  Args:
    x: input array; its ``dtype`` attribute is read in both branches.
    approximate: whether to use the approximate or exact formulation.

  Returns:
    The element-wise GELU of ``x``.
  """
  if not approximate:
    # Exact formulation via the error function, cast back to x's dtype.
    exact = x * (lax.erf(x / np.sqrt(2)) + 1) / 2
    return jnp.array(exact, dtype=x.dtype)

  # tanh-based approximation; the constant is cast to x's dtype first.
  sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)
  tanh_arg = sqrt_2_over_pi * (x + 0.044715 * (x ** 3))
  cdf = 0.5 * (1.0 + jnp.tanh(tanh_arg))
  return x * cdf

def glu(x, axis=-1):
  """Gated linear unit activation function.

  Splits ``x`` into two equal halves along ``axis`` and multiplies the first
  half by the sigmoid of the second half.

  Args:
    x: input array; its size along ``axis`` must be even.
    axis: the axis along which ``x`` is split (default -1).

  Returns:
    The first half of ``x`` gated by ``sigmoid`` of the second half.
  """
  axis_size = x.shape[axis]
  assert axis_size % 2 == 0, "axis size must be divisible by 2"
  values, gates = jnp.split(x, 2, axis)
  return values * sigmoid(gates)

# other functions

def log_softmax(x, axis=-1):
  r"""Log-Softmax function.

  Computes the logarithm of the :code:`softmax` function, which rescales
  elements to the range :math:`[-\infty, 0)`.

  .. math ::
    \mathrm{log\_softmax}(x) = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
    \right)

  Args:
    x: input array.
    axis: the axis or axes along which the :code:`log_softmax` should be
      computed. Either an integer or a tuple of integers.

  Returns:
    An array with the same shape as ``x``.
  """
  # Subtract the per-axis maximum before exponentiating; the maximum is
  # wrapped in stop_gradient so it is treated as a constant when
  # differentiating.
  axis_max = x.max(axis, keepdims=True)
  shifted = x - lax.stop_gradient(axis_max)
  log_normalizer = jnp.log(jnp.sum(jnp.exp(shifted), axis, keepdims=True))
  return shifted - log_normalizer

def softmax(x, axis=-1):
  r"""Softmax function.

  Computes the function which rescales elements to the range :math:`[0, 1]`
  such that the elements along :code:`axis` sum to :math:`1`.

  .. math ::
    \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

  Args:
    x: input array.
    axis: the axis or axes along which the softmax should be computed. The
      softmax output summed across these dimensions should sum to :math:`1`.
      Either an integer or a tuple of integers.

  Returns:
    An array with the same shape as ``x``.
  """
  # Subtract the per-axis maximum before exponentiating; the maximum is
  # wrapped in stop_gradient so it is treated as a constant when
  # differentiating.
  axis_max = x.max(axis, keepdims=True)
  exponentials = jnp.exp(x - lax.stop_gradient(axis_max))
  normalizer = exponentials.sum(axis, keepdims=True)
  return exponentials / normalizer

def normalize(x, axis=-1, mean=None, variance=None, epsilon=1e-5):
  """Normalizes an array by subtracting mean and dividing by sqrt(var).

  Args:
    x: input array.
    axis: axis or axes over which the mean and variance are computed when
      they are not supplied (default -1).
    mean: optional precomputed mean; computed from ``x`` when ``None``.
    variance: optional precomputed variance; computed from ``x`` when
      ``None``.
    epsilon: value added to the variance before the reciprocal square root
      (default 1e-5).

  Returns:
    ``(x - mean) * rsqrt(variance + epsilon)``.
  """
  if mean is None:
    mean = jnp.mean(x, axis, keepdims=True)
  if variance is None:
    # E[x^2] - E[x]^2 is traditionally regarded as less accurate than
    # jnp.var's mean((x - mean(x))**2), but it may be faster and, for typical
    # activation distributions in low-precision arithmetic, even more
    # accurate when used in neural network normalization layers.
    mean_of_squares = jnp.mean(jnp.square(x), axis, keepdims=True)
    variance = mean_of_squares - jnp.square(mean)
  inverse_std = lax.rsqrt(variance + epsilon)
  return (x - mean) * inverse_std

def one_hot(x, num_classes, *, dtype=jnp.float64):
  """One-hot encodes the given indices.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension.
    dtype: optional, a float dtype for the returned values (default float64 if
      jax_enable_x64 is true, otherwise float32).

  Returns:
    An array of shape ``x.shape + (num_classes,)`` with the requested dtype.
  """
  # num_classes determines the output shape, so it must be a concrete int.
  num_classes = core.concrete_or_error(
      int, num_classes,
      "The error arose in jax.nn.one_hot argument num_classes.")
  dtype = dtypes.canonicalize_dtype(dtype)
  x = jnp.asarray(x)
  # Compare every index (with a trailing axis added) against the class ids
  # 0..num_classes-1; broadcasting yields the one-hot mask. Out-of-range
  # indices match no class id and therefore encode as all zeros.
  lhs = x[..., jnp.newaxis]
  rhs = lax.broadcasted_iota(x.dtype, (num_classes,), 0)
  return jnp.array(lhs == rhs, dtype=dtype)

def relu6(x):
  r"""Rectified Linear Unit 6 activation function.

  Computes the element-wise function

  .. math::
    \mathrm{relu6}(x) = \min(\max(x, 0), 6)

  Args:
    x: input array.

  Returns:
    ``x`` limited element-wise to the interval ``[0, 6]``.
  """
  floored = jnp.maximum(x, 0)
  return jnp.minimum(floored, 6.)

def hard_sigmoid(x):
  r"""Hard Sigmoid activation function.

  Computes the element-wise function

  .. math::
    \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}

  Args:
    x: input array.

  Returns:
    ``relu6(x + 3)`` divided by 6.
  """
  shifted = x + 3.
  return relu6(shifted) / 6.

def hard_silu(x):
  r"""Hard SiLU activation function

  Computes the element-wise function

  .. math::
    \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)

  Args:
    x: input array.

  Returns:
    ``x`` multiplied element-wise by ``hard_sigmoid(x)``.
  """
  gate = hard_sigmoid(x)
  return x * gate


# ``hard_swish`` is an alias bound to the same function object as
# ``hard_silu``.
hard_swish = hard_silu