jax.nn.relu

jax.nn.relu(x)

Rectified linear unit activation function.

Computes the element-wise function:

\[\mathrm{relu}(x) = \max(x, 0)\]

except under differentiation, we take:

\[\nabla \mathrm{relu}(0) = 0\]
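Both halves of this definition can be checked directly. The sketch below compares the forward pass with `jax.numpy.maximum` and evaluates the derivative with `jax.grad`, including at the point x = 0. The input values are arbitrary and chosen only for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-3.0, -0.5, 0.0, 0.25, 4.0])

# Forward pass: relu agrees with the element-wise max(x, 0).
print(bool(jnp.all(jax.nn.relu(x) == jnp.maximum(x, 0.0))))  # True

# Backward pass: derivative of relu at a few scalar points.
relu_grad = jax.grad(jax.nn.relu)
print(float(relu_grad(-1.0)))  # 0.0
print(float(relu_grad(0.0)))   # 0.0, the convention relu'(0) = 0
print(float(relu_grad(2.0)))   # 1.0

# Element-wise derivative over the whole array via vmap.
print(jax.vmap(relu_grad)(x).tolist())  # [0.0, 0.0, 0.0, 1.0, 1.0]
```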

For more information, see Numerical influence of ReLU’(0) on backpropagation.
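In JAX, `jax.nn.relu` is a `jax.custom_jvp` object, which is how a specific derivative at 0 can be attached to it. The following is a minimal sketch of that technique using a hypothetical function `my_relu`; it is not JAX's actual source, only an illustration of how a `custom_jvp` rule can pin the derivative at 0 to 0.

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def my_relu(x):
    # Forward pass: element-wise max(x, 0).
    return jnp.maximum(x, 0)

@my_relu.defjvp
def my_relu_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = my_relu(x)
    # The tangent passes through only where x > 0, so the derivative at 0 is 0.
    y_dot = jnp.where(x > 0, x_dot, jnp.zeros_like(x_dot))
    return y, y_dot

print(float(jax.grad(my_relu)(0.0)))  # 0.0
print(float(jax.grad(my_relu)(1.5)))  # 1.0
```

Because the rule is written as a JVP, reverse-mode differentiation (`jax.grad`) is derived from it automatically.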

Parameters:

x (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – input array

Return type:

Array

Returns:

An array with the same shape as x, containing the element-wise ReLU of x.
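Since the function is applied element-wise, the output has the same shape as the input, and a Python scalar yields a 0-d array. A short sketch with arbitrary inputs:

```python
import jax
import jax.numpy as jnp

# A Python scalar is accepted and returns a 0-d array.
s = jax.nn.relu(-3.0)
print(s.shape)  # ()

# A 2-D array keeps its shape; negative entries become 0.
m = jnp.array([[-1.0, 2.0], [3.0, -4.0]])
out = jax.nn.relu(m)
print(out.shape)     # (2, 2)
print(out.tolist())  # [[0.0, 2.0], [3.0, 0.0]]
```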

Example

>>> import jax
>>> jax.nn.relu(jax.numpy.array([-2., -1., -0.5, 0, 0.5, 1., 2.]))
Array([0. , 0. , 0. , 0. , 0.5, 1. , 2. ], dtype=float32)

See also

relu6()