jax.nn package

Common functions for neural network libraries.

Activation functions

relu(x) Rectified linear unit activation function.
sigmoid(x) Sigmoid activation function.
softplus(x) Softplus activation function.
soft_sign(x) Soft-sign activation function.
swish(x) Swish activation function.
log_sigmoid(x) Log-sigmoid activation function.
leaky_relu(x[, negative_slope]) Leaky rectified linear unit activation function.
hard_tanh(x) Hard tanh activation function.
elu(x[, alpha]) Exponential linear unit activation function.
celu(x[, alpha]) Continuously-differentiable exponential linear unit activation.
selu(x) Scaled exponential linear unit activation.
gelu(x) Gaussian error linear unit activation function.
glu(x[, axis]) Gated linear unit activation function.
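
A minimal usage sketch for a few of the activation functions above; the input values and shapes are illustrative only. All of the elementwise activations return an array of the same shape as their input, while glu halves the size of the gated axis.

import jax.numpy as jnp
from jax import nn

x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])

# Elementwise activations: same shape out as in.
print(nn.relu(x))                            # negatives clamped to 0
print(nn.leaky_relu(x, negative_slope=0.1))  # small slope for negative inputs
print(nn.gelu(x))                            # smooth Gaussian-based gating
print(nn.elu(x, alpha=1.0))                  # exponential curve for negatives

# glu splits the chosen axis in half and gates one half with the
# sigmoid of the other, so that axis must have even length.
z = jnp.arange(8.0).reshape(2, 4)
print(nn.glu(z, axis=-1).shape)              # (2, 2)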

Other functions

softmax(x[, axis]) Softmax function.
log_softmax(x[, axis]) Log-Softmax function.
normalize(x[, axis, mean, variance, epsilon]) Normalizes an array by subtracting mean and dividing by sqrt(var).
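
A short sketch of the other functions, again with illustrative inputs. Note that some newer JAX releases expose normalize under the name standardize; the call below assumes the version documented here.

import jax.numpy as jnp
from jax import nn

logits = jnp.array([[1.0, 2.0, 3.0],
                    [1.0, 1.0, 1.0]])

p = nn.softmax(logits, axis=-1)        # each row sums to 1
lp = nn.log_softmax(logits, axis=-1)   # numerically stable log of softmax

# normalize subtracts the mean and divides by sqrt(variance + epsilon)
# along the given axis; mean and variance are computed from x when not supplied.
z = nn.normalize(logits, axis=-1)

print(p.sum(axis=-1))    # ~[1., 1.]
print(z.mean(axis=-1))   # ~[0., 0.]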