jax.nn.normalize

jax.nn.normalize(x, axis=-1, mean=None, variance=None, epsilon=1e-05)

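Normalizes an array along `axis` by subtracting the mean and dividing by the square root of the variance. If `mean` or `variance` is not supplied, it is computed from `x` along `axis`; `epsilon` is a small constant that keeps the division numerically stable.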
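The computation can be sketched in plain `jax.numpy` as follows. This is an illustrative re-implementation, not the library source: the name `normalize_sketch` is hypothetical, and adding `epsilon` to the variance inside the square root is an assumption about how the parameter is applied.

```python
import jax.numpy as jnp

def normalize_sketch(x, axis=-1, mean=None, variance=None, epsilon=1e-5):
    """Hypothetical stand-in mirroring the signature above."""
    # Fall back to statistics computed along `axis` when none are supplied.
    if mean is None:
        mean = jnp.mean(x, axis=axis, keepdims=True)
    if variance is None:
        variance = jnp.var(x, axis=axis, keepdims=True)
    # Assumption: epsilon is added to the variance before the square root.
    return (x - mean) / jnp.sqrt(variance + epsilon)

x = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 6.0, 8.0]])
y = normalize_sketch(x)  # each row now has roughly zero mean and unit variance
```

With the default `axis=-1`, each row is normalized independently; passing precomputed `mean` and `variance` (for example, running statistics) skips the reduction over `x`.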