jax.nn module#

Common functions for neural network libraries.

Activation functions#

relu(x)

Rectified linear unit activation function.

relu6(x)

Rectified Linear Unit 6 activation function.

sigmoid(x)

Sigmoid activation function.

softplus(x)

Softplus activation function.

sparse_plus(x)

Sparse plus function.

soft_sign(x)

Soft-sign activation function.

silu(x)

SiLU (aka swish) activation function.

swish(x)

SiLU (aka swish) activation function.

log_sigmoid(x)

Log-sigmoid activation function.

leaky_relu(x[, negative_slope])

Leaky rectified linear unit activation function.

hard_sigmoid(x)

Hard Sigmoid activation function.

hard_silu(x)

Hard SiLU (swish) activation function.

hard_swish(x)

Hard SiLU (swish) activation function.

hard_tanh(x)

Hard \(\mathrm{tanh}\) activation function.

elu(x[, alpha])

Exponential linear unit activation function.

celu(x[, alpha])

Continuously-differentiable exponential linear unit activation.

selu(x)

Scaled exponential linear unit activation.

gelu(x[, approximate])

Gaussian error linear unit activation function.

glu(x[, axis])

Gated linear unit activation function.

squareplus(x[, b])

Squareplus activation function.

mish(x)

Mish activation function.
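
A minimal usage sketch for the activation functions above (assuming JAX is installed; the sample array is illustrative, not taken from this page). Each function applies elementwise to an array:

```python
import jax.numpy as jnp
from jax import nn

x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])

nn.relu(x)                             # [0., 0., 0., 0.5, 2.] -- negatives clamped to 0
nn.sigmoid(x)                          # values squashed into (0, 1)
nn.gelu(x)                             # smooth ReLU-like curve (tanh approximation by default)
nn.leaky_relu(x, negative_slope=0.01)  # small linear slope for x < 0
```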

Other functions#

softmax(x[, axis, where, initial])

Softmax function.

log_softmax(x[, axis, where, initial])

Log-Softmax function.

logsumexp(a[, axis, b, keepdims, ...])

Compute the log of the sum of exponentials of input elements.

standardize(x[, axis, mean, variance, ...])

Normalizes an array by subtracting mean and dividing by \(\sqrt{\mathrm{variance}}\).

one_hot(x, num_classes, *[, dtype, axis])

One-hot encodes the given indices.
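
A short sketch combining several of the functions above in a common pattern, per-example softmax cross-entropy (the logits and labels here are made up for illustration):

```python
import jax.numpy as jnp
from jax import nn

logits = jnp.array([[1.0, 2.0, 0.5],
                    [0.1, 0.1, 3.0]])
labels = jnp.array([1, 2])

probs = nn.softmax(logits, axis=-1)          # each row sums to 1
log_probs = nn.log_softmax(logits, axis=-1)  # numerically stabler than jnp.log(probs)
targets = nn.one_hot(labels, num_classes=3)  # shape (2, 3), float one-hot rows

# Cross-entropy per example, a common combination of log_softmax and one_hot.
loss = -jnp.sum(targets * log_probs, axis=-1)
```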