jax.scipy.special.kl_div

jax.scipy.special.kl_div(p, q)

Elementwise function for computing the Kullback-Leibler divergence.

JAX implementation of scipy.special.kl_div.

\[\mathrm{kl\_div}(p, q) = \begin{cases} p\log(p/q) - p + q & p>0,\ q>0\\ q & p=0,\ q\ge 0\\ \infty & \text{otherwise} \end{cases}\]
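The short sketch below evaluates one input pair per branch of the definition, assuming a standard JAX installation. It also recomputes the first branch by hand to show that the p > 0, q > 0 case includes the extra -p + q terms, which is what keeps the function continuous as p approaches 0.

```python
import jax.numpy as jnp
from jax.scipy.special import kl_div

# One (p, q) pair per branch of the piecewise definition:
#   index 0: p > 0, q > 0   -> p*log(p/q) - p + q
#   index 1: p = 0, q >= 0  -> q
#   index 2: p < 0          -> inf
#   index 3: p > 0, q = 0   -> inf
p = jnp.array([0.5, 0.0, -1.0, 0.5])
q = jnp.array([0.25, 2.0, 1.0, 0.0])

out = kl_div(p, q)

# Hand computation of the first branch for comparison.
expected_first = 0.5 * jnp.log(0.5 / 0.25) - 0.5 + 0.25

print(out)
print(expected_first)
```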

Parameters:
  • p (jax.typing.ArrayLike) – arraylike, real-valued.

  • q (jax.typing.ArrayLike) – arraylike, real-valued.
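As a sketch of what "arraylike" admits, the example below passes a NumPy integer array for p and a plain Python float for q (both are accepted as jax.typing.ArrayLike). It assumes JAX's usual behavior of combining a scalar with every array element and promoting integer inputs to a floating-point dtype; the names p_float and expected are only for the comparison.

```python
import numpy as np
import jax.numpy as jnp
from jax.scipy.special import kl_div

# A NumPy integer array and a plain Python float are both valid ArrayLike inputs.
p = np.array([1, 2, 3])
q = 1.0

# The scalar q is combined with every element of p, and the integer
# input is promoted to a floating-point dtype before the computation.
out = kl_div(p, q)

# Every element falls in the p > 0, q > 0 branch: p*log(p/q) - p + q.
p_float = jnp.asarray(p, dtype=jnp.float32)
expected = p_float * jnp.log(p_float / q) - p_float + q

print(out.dtype, out.shape)
```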

Returns:

Array of elementwise KL-divergence values.

Return type:

Array
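Because kl_div returns one value per element, the divergence between two discrete distributions is obtained by summing its output. The sketch below assumes p and q are probability vectors that each sum to 1; in that case the -p + q terms cancel in the sum, so the total agrees with summing jax.scipy.special.rel_entr, which computes only the p log(p/q) term.

```python
import jax.numpy as jnp
from jax.scipy.special import kl_div, rel_entr

# Two discrete probability distributions over the same three outcomes.
p = jnp.array([0.1, 0.4, 0.5])
q = jnp.array([0.3, 0.3, 0.4])

# kl_div works elementwise; summing the terms gives D_KL(P || Q).
d_kl = jnp.sum(kl_div(p, q))

# Since sum(p) == sum(q) == 1, the extra -p + q terms cancel in the sum,
# so these alternatives agree up to floating-point rounding.
d_rel = jnp.sum(rel_entr(p, q))
d_manual = jnp.sum(p * jnp.log(p / q))

print(d_kl, d_rel, d_manual)
```

For inputs that are not normalized, the two sums differ by sum(q) - sum(p), which is exactly the contribution of the extra terms.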