jax.nn.standardize

jax.nn.standardize(x, axis=-1, mean=None, variance=None, epsilon=1e-05, where=None)

Normalizes an array by subtracting mean and dividing by \(\sqrt{\mathrm{variance} + \mathrm{epsilon}}\). If mean or variance is None, it is computed from x along axis.
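As an illustrative sketch, the formula can be checked against a manual computation on a small made-up array. The reference uses jnp.mean and jnp.var with keepdims=True. The library may compute the variance by a numerically different but mathematically equivalent route, so the comparison uses a tolerance rather than exact equality.

```python
import jax
import jax.numpy as jnp

# Two rows with very different scales; by default each row is standardized (axis=-1).
x = jnp.array([[1.0, 2.0, 3.0, 4.0],
               [10.0, 20.0, 30.0, 40.0]])
eps = 1e-5  # same value as the default epsilon

# Manual version of the documented formula. keepdims=True keeps the reduced
# axis as size 1 so the statistics broadcast back against x.
mean = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.var(x, axis=-1, keepdims=True)
manual = (x - mean) / jnp.sqrt(var + eps)

out = jax.nn.standardize(x)  # axis=-1, epsilon=1e-5 by default

print(jnp.allclose(out, manual, atol=1e-5))  # True
# Standardization removes scale, so both rows map to (nearly) the same values.
print(jnp.allclose(out[0], out[1], atol=1e-4))  # True
```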

Parameters:
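  • x (ArrayLike) – Input array.

  • axis (int | tuple[int, …] | None) – Axis or axes along which mean and variance are computed when they are not supplied. Defaults to -1 (the last axis); None reduces over all axes.

  • mean (ArrayLike | None) – Precomputed mean, broadcastable against x. If None (the default), it is computed from x.

  • variance (ArrayLike | None) – Precomputed variance, broadcastable against x. If None (the default), it is computed from x.

  • epsilon (ArrayLike) – Small constant added to the variance before taking the square root, for numerical stability. Defaults to 1e-5.

  • where (ArrayLike | None) – Optional boolean mask; only elements where it is True contribute to the computed mean and variance.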
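The optional parameters can be exercised with a short sketch like the one below. The arrays, the outlier value, and the tolerances are illustrative choices, not values from the library. For where, the sketch only inspects the unmasked entries.

```python
import jax
import jax.numpy as jnp

x = jnp.arange(24.0).reshape(2, 3, 4)

# axis as a tuple: one mean/variance per leading index, pooled over axes 1 and 2.
y = jax.nn.standardize(x, axis=(1, 2))

# Precomputed statistics are used as given; they only need to broadcast
# against x, hence keepdims=True.
mu = jnp.mean(x, axis=(1, 2), keepdims=True)
var = jnp.var(x, axis=(1, 2), keepdims=True)
y_pre = jax.nn.standardize(x, mean=mu, variance=var)

# where: only True entries contribute to the statistics, so the
# outlier 100.0 is excluded from the mean and variance.
z = jnp.array([1.0, 2.0, 3.0, 100.0])
mask = jnp.array([True, True, True, False])
z_std = jax.nn.standardize(z, where=mask)
# Reference for the unmasked entries: mean 2.0, population variance 2/3.
z_ref = (z[:3] - 2.0) / jnp.sqrt(2.0 / 3.0 + 1e-5)

# epsilon on a constant input: x - mean is 0 and the variance is 0.
c = jnp.full((4,), 3.0)
c_std = jax.nn.standardize(c)               # zeros: 0 / sqrt(0 + 1e-5)
c_nan = jax.nn.standardize(c, epsilon=0.0)  # NaN: 0 / sqrt(0)

print(y.mean(axis=(1, 2)))                        # close to [0, 0]
print(jnp.allclose(y, y_pre, atol=1e-4))          # True
print(jnp.allclose(z_std[:3], z_ref, atol=1e-5))  # True
print(c_std, c_nan)
```

The constant-input case shows why epsilon defaults to a small positive value: without it, a zero-variance slice produces NaN instead of zeros.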

Return type:

Array