jax.scipy.special.logit

jax.scipy.special.logit(x)

This function is a jax.custom_jvp object, i.e. a function with a custom derivative (JVP) rule.

Logit ufunc for ndarrays.

LAX-backend implementation of scipy.special.logit().
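Because this is a LAX-backend implementation defined as a jax.custom_jvp function, it should compose with JAX transformations such as jax.grad and jax.vmap. The sketch below assumes a standard JAX install. It compares the gradient JAX computes against the analytic derivative of log(p/(1-p)), which is 1/(p(1-p)). The input values are arbitrary illustrations.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logit

# Sample probabilities strictly inside (0, 1), where logit is finite.
p = jnp.array([0.1, 0.25, 0.5, 0.9])

# jax.grad requires a scalar-output function, so take the gradient of the
# scalar logit and vmap it over the array to get element-wise derivatives.
dlogit = jax.vmap(jax.grad(logit))
grads = dlogit(p)

# Analytic derivative: d/dp log(p / (1 - p)) = 1 / p + 1 / (1 - p) = 1 / (p * (1 - p)).
expected = 1.0 / (p * (1.0 - p))

print(grads)
print(expected)
```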

Original docstring below.

The logit function is defined as logit(p) = log(p/(1-p)). Note that logit(0) = -inf, logit(1) = inf, and logit(p) for p<0 or p>1 yields nan.
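The following minimal sketch exercises the boundary cases stated above: -inf at 0, inf at 1, and nan outside [0, 1]. It also shows that logit inverts jax.scipy.special.expit (the logistic sigmoid) on the open interval (0, 1). The specific inputs are illustrative choices and do not come from the original page.

```python
import jax.numpy as jnp
from jax.scipy.special import logit, expit

# Boundary and out-of-domain inputs alongside two ordinary values.
p = jnp.array([0.0, 0.25, 0.5, 1.0, -0.5, 1.5])
out = logit(p)
# Pattern per the definition: [-inf, log(1/3), 0, inf, nan, nan]
print(out)

# expit(x) = 1 / (1 + exp(-x)) maps the reals into (0, 1); logit maps back.
x = jnp.array([-3.0, 0.0, 2.0])
roundtrip = logit(expit(x))
print(roundtrip)
```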

Parameters:
  • x (ndarray) – The ndarray to apply logit to element-wise.

  • out (ndarray, optional) – Optional output array for the function results. (SciPy only: the JAX signature above accepts only x.)

Returns:

An ndarray of the same shape as x, whose entries are the logit of the corresponding entries of x.
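As a quick sanity check of the return contract, the sketch below applies logit to a 2-D array. The output keeps the input's shape because the function is applied element-wise. The sketch also checks the symmetry logit(1 - p) = -logit(p), which follows directly from log(p/(1-p)). The array values are arbitrary.

```python
import jax.numpy as jnp
from jax.scipy.special import logit

# A 2-D input; logit is applied element-wise, so the shape is preserved.
p = jnp.array([[0.1, 0.2, 0.3],
               [0.7, 0.8, 0.9]])
out = logit(p)
print(out.shape)  # (2, 3)

# From the definition: log((1 - p) / p) = -log(p / (1 - p)).
symmetric = jnp.allclose(logit(1.0 - p), -out, atol=1e-4)
print(symmetric)
```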

Return type:

scalar or ndarray