jax.nn package

Common functions for neural network libraries.

Activation functions

relu(x)

Rectified linear unit activation function.

relu6(x)

Rectified Linear Unit 6 activation function.

sigmoid(x)

Sigmoid activation function.

softplus(x)

Softplus activation function.

soft_sign(x)

Soft-sign activation function.

silu(x)

SiLU activation function.

swish(x)

SiLU activation function (alias of silu).

log_sigmoid(x)

Log-sigmoid activation function.

leaky_relu(x[, negative_slope])

Leaky rectified linear unit activation function.

hard_sigmoid(x)

Hard Sigmoid activation function.

hard_silu(x)

Hard SiLU activation function.

hard_swish(x)

Hard SiLU activation function (alias of hard_silu).

hard_tanh(x)

Hard tanh activation function.

elu(x[, alpha])

Exponential linear unit activation function.

celu(x[, alpha])

Continuously-differentiable exponential linear unit activation.

selu(x)

Scaled exponential linear unit activation.

gelu(x[, approximate])

Gaussian error linear unit activation function.

glu(x[, axis])

Gated linear unit activation function.
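
As a quick illustration (not part of the reference entries above), the activations listed here are applied elementwise to array inputs; a minimal sketch:

import jax
import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 7)

jax.nn.relu(x)                             # elementwise max(x, 0)
jax.nn.leaky_relu(x, negative_slope=0.01)  # small slope for x < 0
jax.nn.gelu(x, approximate=True)           # tanh-based approximation of GELU
jax.nn.glu(jnp.arange(8.0), axis=-1)       # splits the given axis in half; its size must be even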

Other functions

softmax(x[, axis])

Softmax function.

log_softmax(x[, axis])

Log-Softmax function.

normalize(x[, axis, mean, variance, epsilon])

Normalizes an array by subtracting mean and dividing by sqrt(var).

one_hot(x, num_classes, *[, dtype])

One-hot encodes the given indices.
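
A minimal usage sketch of these utility functions; the cross-entropy computation at the end is an illustrative assumption, not part of the jax.nn API itself:

import jax.numpy as jnp
from jax import nn

logits = jnp.array([[1.0, 2.0, 3.0],
                    [0.5, 0.5, 0.5]])

probs = nn.softmax(logits, axis=-1)          # each row sums to 1
log_probs = nn.log_softmax(logits, axis=-1)  # numerically stable log of the above
standardized = nn.normalize(logits, axis=-1) # zero mean, unit variance along the axis

labels = jnp.array([2, 0])
targets = nn.one_hot(labels, num_classes=3)  # shape (2, 3), floating-point dtype by default

# Example: cross-entropy from log-probabilities and one-hot targets.
loss = -jnp.sum(targets * log_probs, axis=-1).mean()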