jax.nn.standardize

jax.nn.standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-05, where=None)

Normalizes an array by subtracting the mean and dividing by \(\sqrt{\mathrm{variance}}\).

Parameters:
  • x (ArrayLike)
  • axis (int | tuple[int, …] | None)
  • mean (ArrayLike | None)
  • variance (ArrayLike | None)
  • epsilon (ArrayLike)
  • where (ArrayLike | None)

Return type:

Array
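
The following is a minimal usage sketch, not taken from the official documentation. It assumes that, when mean and variance are left as None, they are computed from x along axis, and that epsilon is added to the variance before the square root is taken. The input values and the names y_ref and y_given are illustrative only.

```python
import jax.numpy as jnp
from jax import nn

# Illustrative input: two rows with very different scales.
x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [10.0, 20.0, 30.0, 40.0]])

# Standardize each row along the last axis (the default, axis=-1).
y = nn.standardize(x)

# Manual reference, assuming epsilon is added to the variance
# before the square root (an assumption, see the note above).
mean = x.mean(axis=-1, keepdims=True)
var = x.var(axis=-1, keepdims=True)
y_ref = (x - mean) / jnp.sqrt(var + 1e-5)

# Precomputed statistics can be passed in instead of being derived from x.
y_given = nn.standardize(x, mean=mean, variance=var)
```

Under these assumptions each row of y has approximately zero mean and unit variance, and y, y_ref, and y_given agree up to floating-point tolerance.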