jax.scipy.special.logit

jax.scipy.special.logit(x)

Logit ufunc for ndarrays.

LAX-backend implementation of scipy.special.logit().
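Because this is a JAX implementation, logit can be differentiated with jax.grad and mapped with jax.vmap. The sketch below compares the autodiff result with the closed-form derivative d/dp log(p/(1-p)) = 1/(p(1-p)). The probe points are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logit

# Scalar derivative of logit via autodiff.
dlogit = jax.grad(logit)

# Element-wise derivatives over several probe points (illustrative values).
ps = jnp.array([0.1, 0.5, 0.9])
autodiff = jax.vmap(dlogit)(ps)

# Closed form of the derivative: 1 / (p * (1 - p)).
closed_form = 1.0 / (ps * (1.0 - ps))

print(dlogit(0.5))                    # 1 / (0.5 * 0.5) = 4.0
print(jnp.allclose(autodiff, closed_form, rtol=1e-5))
```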

Original docstring below.

The logit function is defined as logit(p) = log(p/(1-p)). Note that logit(0) = -inf, logit(1) = inf, and logit(p) for p<0 or p>1 yields nan.
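A quick check of the definition and its edge cases. The input values are arbitrary illustrative choices covering the boundaries 0 and 1, the midpoint, and points outside [0, 1].

```python
import jax.numpy as jnp
from jax.scipy.special import logit

# Boundary values, the midpoint, and two out-of-range points.
p = jnp.array([0.0, 0.5, 1.0, -0.5, 1.5])
out = logit(p)
print(out)  # expected: -inf, 0, inf, nan, nan

# Inside (0, 1), logit matches the defining formula log(p / (1 - p)).
q = jnp.array([0.2, 0.4, 0.8])
print(jnp.allclose(logit(q), jnp.log(q / (1.0 - q))))
```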

Parameters
  • x (ndarray) – The ndarray to apply logit to element-wise.

  • out (ndarray, optional) – Optional output array for the function results. Not supported by the JAX implementation, whose signature accepts only x.

Returns

An ndarray of the same shape as x whose entries are the logit of the corresponding entries of x.
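A small check of the shape-preserving, element-wise behaviour described above. It also uses jax.scipy.special.expit, the logistic sigmoid, which inverts logit on the open interval (0, 1). The 2×2 input is an arbitrary example.

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit

# A 2-D input: logit is applied to each entry independently.
p = jnp.array([[0.1, 0.25],
               [0.5, 0.9]])
z = logit(p)
print(z.shape)  # (2, 2), same shape as p

# expit undoes logit for entries strictly between 0 and 1.
p_back = expit(z)
print(jnp.allclose(p_back, p, atol=1e-6))
```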

Return type

scalar or ndarray