Source code for jax._src.nn.functions

# Copyright 2019 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared neural network activations and other functions."""

from functools import partial
import operator
import warnings
import numpy as np
from typing import Any, Optional, Tuple, Union

import jax
from jax import custom_jvp
from jax._src import dtypes
from jax import lax
from jax import core
from jax.core import AxisName
from jax._src import util
from jax.scipy.special import expit
from jax.scipy.special import logsumexp as _logsumexp
import jax.numpy as jnp

# Annotation-only alias for array arguments and return values in this module.
# Because it is `typing.Any`, it documents intent but performs no static or
# runtime type checking.
Array = Any

# activations

@custom_jvp
@jax.jit
def relu(x: Array) -> Array:
  r"""Rectified linear unit activation function.

  Element-wise, keeps positive entries and maps everything else to zero:

  .. math::
    \mathrm{relu}(x) = \max(x, 0)

  The derivative at the kink is defined to be zero:

  .. math::
    \nabla \mathrm{relu}(0) = 0

  See `Numerical influence of ReLU’(0) on backpropagation
  <https://openreview.net/forum?id=urrcVI-_jRm>`_ for the motivation.

  Args:
    x : input array
  """
  return jnp.maximum(x, 0)


def _relu_jvp_rule(g, ans, x):
  # Tangent rule for `relu`: pass the incoming tangent through where x > 0
  # and emit zeros elsewhere, so the derivative at exactly x == 0 is 0.
  # For behavior at 0, see https://openreview.net/forum?id=urrcVI-_jRm
  del ans  # the primal output is not needed
  zero_tangent = lax.full_like(g, 0)
  return lax.select(x > 0, g, zero_tangent)


relu.defjvps(_relu_jvp_rule)
@jax.jit
def softplus(x: Array) -> Array:
  r"""Softplus activation function.

  Applies, element-wise,

  .. math::
    \mathrm{softplus}(x) = \log(1 + e^x)

  which is evaluated as ``logaddexp(x, 0)``, i.e. :math:`\log(e^x + e^0)`.

  Args:
    x : input array
  """
  result = jnp.logaddexp(x, 0)
  return result
@jax.jit
def soft_sign(x: Array) -> Array:
  r"""Soft-sign activation function.

  Applies, element-wise,

  .. math::
    \mathrm{soft\_sign}(x) = \frac{x}{|x| + 1}

  Args:
    x : input array
  """
  denominator = jnp.abs(x) + 1
  return x / denominator
@jax.jit
def sigmoid(x: Array) -> Array:
  r"""Sigmoid activation function.

  Applies the logistic function element-wise:

  .. math::
    \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}}

  Args:
    x : input array
  """
  logistic = lax.logistic(x)
  return logistic
@jax.jit
def silu(x: Array) -> Array:
  r"""SiLU activation function.

  Multiplies the input by its own sigmoid, element-wise:

  .. math::
    \mathrm{silu}(x) = x \cdot \mathrm{sigmoid}(x) = \frac{x}{1 + e^{-x}}

  Args:
    x : input array
  """
  gate = sigmoid(x)
  return x * gate


# `swish` is another name for the same function.
swish = silu
@jax.jit
def log_sigmoid(x: Array) -> Array:
  r"""Log-sigmoid activation function.

  Applies, element-wise,

  .. math::
    \mathrm{log\_sigmoid}(x) = \log(\mathrm{sigmoid}(x)) = -\log(1 + e^{-x})

  which is computed as ``-softplus(-x)``.

  Args:
    x : input array
  """
  negated = -x
  return -softplus(negated)
@jax.jit
def elu(x: Array, alpha: Array = 1.0) -> Array:
  r"""Exponential linear unit activation function.

  Applies, element-wise,

  .. math::
    \mathrm{elu}(x) = \begin{cases}
      x, & x > 0\\
      \alpha \left(\exp(x) - 1\right), & x \le 0
    \end{cases}

  Args:
    x : input array
    alpha : scalar or array of alpha values (default: 1.0)
  """
  is_positive = x > 0
  # Positive entries are replaced by 0 before expm1, so the branch that
  # `where` discards is never evaluated on large inputs (the usual guard
  # against inf/NaN leaking through `jnp.where` gradients).
  nonpositive_x = jnp.where(is_positive, 0., x)
  return jnp.where(is_positive, x, alpha * jnp.expm1(nonpositive_x))
@jax.jit
def leaky_relu(x: Array, negative_slope: Array = 1e-2) -> Array:
  r"""Leaky rectified linear unit activation function.

  Applies, element-wise,

  .. math::
    \mathrm{leaky\_relu}(x) = \begin{cases}
      x, & x \ge 0\\
      \alpha x, & x < 0
    \end{cases}

  where :math:`\alpha` = :code:`negative_slope`.

  Args:
    x : input array
    negative_slope : array or scalar specifying the negative slope (default: 0.01)
  """
  scaled = negative_slope * x
  return jnp.where(x >= 0, x, scaled)
@jax.jit
def hard_tanh(x: Array) -> Array:
  r"""Hard :math:`\mathrm{tanh}` activation function.

  Clamps each element to :math:`[-1, 1]`:

  .. math::
    \mathrm{hard\_tanh}(x) = \begin{cases}
      -1, & x < -1\\
      x, & -1 \le x \le 1\\
      1, & 1 < x
    \end{cases}

  Args:
    x : input array
  """
  lower_clamped = jnp.where(x < -1, -1, x)
  return jnp.where(x > 1, 1, lower_clamped)
@jax.jit
def celu(x: Array, alpha: Array = 1.0) -> Array:
  r"""Continuously-differentiable exponential linear unit activation.

  Applies, element-wise,

  .. math::
    \mathrm{celu}(x) = \begin{cases}
      x, & x > 0\\
      \alpha \left(\exp(\frac{x}{\alpha}) - 1\right), & x \le 0
    \end{cases}

  For more information, see
  `Continuously Differentiable Exponential Linear Units
  <https://arxiv.org/pdf/1704.07483.pdf>`_.

  Args:
    x : input array
    alpha : array or scalar (default: 1.0)
  """
  positive_part = jnp.maximum(x, 0.0)
  negative_part = alpha * jnp.expm1(jnp.minimum(x, 0.0) / alpha)
  return positive_part + negative_part
@jax.jit
def selu(x: Array) -> Array:
  r"""Scaled exponential linear unit activation.

  Applies, element-wise,

  .. math::
    \mathrm{selu}(x) = \lambda \begin{cases}
      x, & x > 0\\
      \alpha e^x - \alpha, & x \le 0
    \end{cases}

  where :math:`\lambda = 1.0507009873554804934193349852946` and
  :math:`\alpha = 1.6732632423543772848170429916717`.

  For more information, see
  `Self-Normalizing Neural Networks
  <https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf>`_.

  Args:
    x : input array
  """
  selu_alpha = 1.6732632423543772848170429916717
  selu_lambda = 1.0507009873554804934193349852946
  scaled_elu = elu(x, selu_alpha)
  return selu_lambda * scaled_elu
# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
# @partial(jax.jit, static_argnames=("approximate",))
def gelu(x: Array, approximate: bool = True) -> Array:
  r"""Gaussian error linear unit activation function.

  If ``approximate=False``, computes the element-wise function:

  .. math::
    \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left(
      \frac{x}{\sqrt{2}} \right) \right)

  If ``approximate=True``, uses the approximate formulation of GELU:

  .. math::
    \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{tanh} \left(
      \sqrt{\frac{2}{\pi}} \left(x + 0.044715 x^3 \right) \right) \right)

  For more information, see `Gaussian Error Linear Units (GELUs)
  <https://arxiv.org/abs/1606.08415>`_, section 2.

  Args:
    x : input array, or any array-like accepted by ``jnp.asarray`` (e.g. a
      Python scalar or list). Non-inexact dtypes are promoted with
      ``dtypes.to_inexact_dtype``.
    approximate: whether to use the approximate or exact formulation.

  Returns:
    The GELU of ``x``, with an inexact dtype.
  """
  # Convert first: Python scalars and lists have no `.dtype` / `.astype`,
  # so the promotion below would otherwise raise AttributeError.
  x = jnp.asarray(x)
  # Promote to nearest float-like dtype.
  x = x.astype(dtypes.to_inexact_dtype(x.dtype))
  if approximate:
    sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)
    cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x ** 3))))
    return x * cdf
  else:
    sqrt_2 = np.sqrt(2).astype(x.dtype)
    return jnp.array(x * (lax.erf(x / sqrt_2) + 1) / 2, dtype=x.dtype)
@partial(jax.jit, static_argnames=("axis",))
def glu(x: Array, axis: int = -1) -> Array:
  r"""Gated linear unit activation function.

  Splits ``x`` into two equal halves :math:`a, b` along ``axis`` and returns
  :math:`a \cdot \mathrm{sigmoid}(b)`.

  Args:
    x : input array
    axis: the axis along which the split should be computed (default: -1)
  """
  assert x.shape[axis] % 2 == 0, "axis size must be divisible by 2"
  value_half, gate_half = jnp.split(x, 2, axis)
  return value_half * sigmoid(gate_half)
# other functions

logsumexp = _logsumexp


@partial(jax.jit, static_argnames=("axis",))
def log_softmax(x: Array,
                axis: Optional[Union[int, Tuple[int, ...]]] = -1,
                where: Optional[Array] = None,
                initial: Optional[Array] = None) -> Array:
  r"""Log-Softmax function.

  Computes the logarithm of the :code:`softmax` function, which rescales
  elements to the range :math:`[-\infty, 0]`.

  .. math ::
    \mathrm{log\_softmax}(x) = \log \left( \frac{\exp(x_i)}{\sum_j \exp(x_j)}
    \right)

  Args:
    x : input array
    axis: the axis or axes along which the :code:`log_softmax` should be
      computed. Either an integer or a tuple of integers.
    where: Elements to include in the :code:`log_softmax`. Excluded elements
      do not contribute to the normalizer and are set to :math:`-\infty` in
      the output.
    initial: The minimum value used to shift the input array. Must be present
      when :code:`where` is not None; excluded elements are replaced by it
      before exponentiation.

  Returns:
    The log-softmax of ``x``.
  """
  x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
  # Replace masked-out entries by `initial` (which is <= x_max, since x_max is
  # taken with that floor) before exponentiating. Otherwise a large excluded
  # value overflows exp() to inf and poisons results and gradients.
  x_safe = x if where is None else jnp.where(where, x, initial)
  shifted = x_safe - lax.stop_gradient(x_max)
  shifted_logsumexp = jnp.log(
      jnp.sum(jnp.exp(shifted), axis, where=where, keepdims=True))
  result = shifted - shifted_logsumexp
  if where is not None:
    # Excluded entries get log(0) = -inf, matching a softmax value of 0.
    return jnp.where(where, result, -jnp.inf)
  return result
# TODO(phawkins): this jit was found to change numerics in a test. Debug this.
#@partial(jax.jit, static_argnames=("axis",))
def softmax(x: Array,
            axis: Optional[Union[int, Tuple[int, ...]]] = -1,
            where: Optional[Array] = None,
            initial: Optional[Array] = None) -> Array:
  r"""Softmax function.

  Computes the function which rescales elements to the range :math:`[0, 1]`
  such that the elements along :code:`axis` sum to :math:`1`.

  .. math ::
    \mathrm{softmax}(x) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

  Args:
    x : input array
    axis: the axis or axes along which the softmax should be computed. The
      softmax output summed across these dimensions should sum to :math:`1`.
      Either an integer or a tuple of integers.
    where: Elements to include in the :code:`softmax`. Excluded elements do
      not contribute to the normalizer and are set to :math:`0` in the output.
    initial: The minimum value used to shift the input array. Must be present
      when :code:`where` is not None; excluded elements are replaced by it
      before exponentiation.

  Returns:
    The softmax of ``x``.
  """
  x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
  # Replace masked-out entries by `initial` (which is <= x_max, since x_max is
  # taken with that floor) before exponentiating. Otherwise a large excluded
  # value overflows exp() to inf and poisons results and gradients.
  x_safe = x if where is None else jnp.where(where, x, initial)
  unnormalized = jnp.exp(x_safe - lax.stop_gradient(x_max))
  result = unnormalized / jnp.sum(unnormalized, axis, where=where,
                                  keepdims=True)
  if where is not None:
    # Excluded entries carry no probability mass.
    result = jnp.where(where, result, 0)
  return result
@partial(jax.jit, static_argnames=("axis",))
def standardize(x: Array,
                axis: Optional[Union[int, Tuple[int, ...]]] = -1,
                mean: Optional[Array] = None,
                variance: Optional[Array] = None,
                epsilon: Array = 1e-5,
                where: Optional[Array] = None) -> Array:
  r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`.

  Computes :math:`(x - \mathrm{mean}) / \sqrt{\mathrm{variance} + \epsilon}`.

  Args:
    x: input array.
    axis: the axis or axes over which ``mean`` and ``variance`` are computed
      when they are not supplied (default: -1).
    mean: optional precomputed mean. If omitted, it is computed from ``x``
      over ``axis``.
    variance: optional precomputed variance. If omitted, it is computed as
      :math:`E[x^2] - \mathrm{mean}^2` over ``axis`` and clamped at zero.
    epsilon: constant added to the variance before the reciprocal square
      root (default: 1e-5).
    where: optional mask of elements to include when computing ``mean`` and
      ``variance``.

  Returns:
    The standardized array.
  """
  if mean is None:
    mean = jnp.mean(x, axis, keepdims=True, where=where)
  if variance is None:
    # this definition is traditionally seen as less accurate than jnp.var's
    # mean((x - mean(x))**2) but may be faster and even, given typical
    # activation distributions and low-precision arithmetic, more accurate
    # when used in neural network normalization layers
    variance = jnp.mean(
        jnp.square(x), axis, keepdims=True, where=where) - jnp.square(mean)
    # E[x^2] - mean^2 can round to a small negative number (e.g. for nearly
    # constant inputs of large magnitude); if it falls below -epsilon,
    # rsqrt would return NaN. A true variance is never negative.
    variance = jnp.maximum(variance, 0.)
  return (x - mean) * lax.rsqrt(variance + epsilon)
def normalize(x: Array,
              axis: Optional[Union[int, Tuple[int, ...]]] = -1,
              mean: Optional[Array] = None,
              variance: Optional[Array] = None,
              epsilon: Array = 1e-5,
              where: Optional[Array] = None) -> Array:
  r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`.

  Deprecated alias: emits a :class:`DeprecationWarning` and forwards all
  arguments unchanged to :func:`jax.nn.standardize`.
  """
  # stacklevel=2 attributes the warning to the caller's line rather than to
  # this module. Python's default filters only display DeprecationWarning
  # when it is attributed to __main__, so with the default stacklevel=1 the
  # warning was hidden from user scripts.
  warnings.warn("jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.",
                DeprecationWarning, stacklevel=2)
  return standardize(x, axis, mean, variance, epsilon, where)
@partial(jax.jit, static_argnames=("num_classes", "dtype", "axis")) def _one_hot(x: Array, num_classes: int, *, dtype: Any, axis: Union[int, AxisName]) -> Array: num_classes = core.concrete_or_error( int, num_classes, "The error arose in jax.nn.one_hot argument `num_classes`.") dtype = dtypes.canonicalize_dtype(dtype) x = jnp.asarray(x) try: output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1) except TypeError: axis_size = lax.psum(1, axis) if num_classes != axis_size: raise ValueError(f"Expected num_classes to match the size of axis {axis}, " f"but {num_classes} != {axis_size}") from None axis_idx = lax.axis_index(axis) return jnp.asarray(x == axis_idx, dtype=dtype) axis = operator.index(axis) # type: ignore[arg-type] lhs = lax.expand_dims(x, (axis,)) rhs_shape = [1] * x.ndim rhs_shape.insert(output_pos_axis, num_classes) rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis) return jnp.asarray(lhs == rhs, dtype=dtype)
def one_hot(x: Array, num_classes: int, *,
            dtype: Any = jnp.float_, axis: Union[int, AxisName] = -1) -> Array:
  """One-hot encodes the given indices.

  Each index in the input ``x`` is encoded as a vector of zeros of length
  ``num_classes`` with the element at ``index`` set to one::

    >>> jax.nn.one_hot(jnp.array([0, 1, 2]), 3)
    DeviceArray([[1., 0., 0.],
                 [0., 1., 0.],
                 [0., 0., 1.]], dtype=float32)

  Indices outside the range [0, num_classes) will be encoded as zeros::

    >>> jax.nn.one_hot(jnp.array([-1, 3]), 3)
    DeviceArray([[0., 0., 0.],
                 [0., 0., 0.]], dtype=float32)

  Args:
    x: A tensor of indices.
    num_classes: Number of classes in the one-hot dimension. Must be a
      concrete (non-traced) integer.
    dtype: optional, the dtype of the returned values (default
      :obj:`jnp.float_`).
    axis: the position in the output at which the new one-hot dimension is
      inserted (default: -1, i.e. last), or the name of a mapped axis, in
      which case ``num_classes`` must equal that axis' size.
  """
  num_classes = core.concrete_or_error(
      int, num_classes,
      "The error arose in jax.nn.one_hot argument `num_classes`.")
  return _one_hot(x, num_classes, dtype=dtype, axis=axis)
@jax.jit
def relu6(x: Array) -> Array:
  r"""Rectified Linear Unit 6 activation function.

  Clamps each element to :math:`[0, 6]`:

  .. math::
    \mathrm{relu6}(x) = \min(\max(x, 0), 6)

  Args:
    x : input array
  """
  floored = jnp.maximum(x, 0)
  return jnp.minimum(floored, 6.)
@jax.jit
def hard_sigmoid(x: Array) -> Array:
  r"""Hard Sigmoid activation function.

  A piecewise-linear approximation of the sigmoid, applied element-wise:

  .. math::
    \mathrm{hard\_sigmoid}(x) = \frac{\mathrm{relu6}(x + 3)}{6}

  Args:
    x : input array
  """
  shifted = x + 3.
  return relu6(shifted) / 6.
@jax.jit
def hard_silu(x: Array) -> Array:
  r"""Hard SiLU activation function

  Multiplies the input by its hard sigmoid, element-wise:

  .. math::
    \mathrm{hard\_silu}(x) = x \cdot \mathrm{hard\_sigmoid}(x)

  Args:
    x : input array
  """
  gate = hard_sigmoid(x)
  return x * gate


# `hard_swish` is another name for the same function.
hard_swish = hard_silu