jax.numpy.logaddexp2

jax.numpy.logaddexp2 = <jax.custom_derivatives.custom_jvp object>

Logarithm of the sum of exponentiations of the inputs in base-2.

LAX-backend implementation of numpy.logaddexp2(). The original NumPy docstring follows.

logaddexp2(x1, x2, /, out=None, *, where=True, casting='same_kind', order='K', dtype=None, subok=True[, signature, extobj])

Calculates log2(2**x1 + 2**x2). This function is useful in machine learning when the computed probabilities of events are so small that they fall below the range of normal floating-point numbers. In such cases, the base-2 logarithm of the probability can be stored instead, and this function adds probabilities stored in that form.
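To illustrate why working in log space helps, here is a minimal scalar sketch using only the Python standard library. It relies on the identity log2(2**x1 + 2**x2) = max(x1, x2) + log2(1 + 2**(-|x1 - x2|)). The helper name logaddexp2_sketch is hypothetical, and this is not necessarily how NumPy or JAX implement the function internally.

```python
import math


def logaddexp2_sketch(x1: float, x2: float) -> float:
    """Illustrative scalar version of log2(2**x1 + 2**x2) (hypothetical helper)."""
    hi, lo = max(x1, x2), min(x1, x2)
    if hi == lo:
        # Equal inputs (including both +inf or both -inf): 2 * 2**hi == 2**(hi + 1).
        return hi + 1.0
    # lo - hi is negative, so 2**(lo - hi) lies in [0, 1) and cannot overflow.
    return hi + math.log2(1.0 + 2.0 ** (lo - hi))


# 2**-2000 underflows to 0.0 in double precision, so the direct formula is unusable,
# while the log-space version still returns a finite result.
underflowed = 2.0 ** -2000
stable = logaddexp2_sketch(-2000.0, -2000.0)
```

Here underflowed is 0.0, so taking its logarithm directly would fail, whereas stable is -1999.0, the base-2 logarithm of 2 * 2**-2000.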

Parameters

x1, x2 (array_like) – Input values. If x1.shape != x2.shape, they must be broadcastable to a common shape (which becomes the shape of the output).

Returns

result – Base-2 logarithm of 2**x1 + 2**x2. This is a scalar if both x1 and x2 are scalars.

Return type

ndarray

See also

logaddexp

Logarithm of the sum of exponentiations of the inputs.

Notes

New in NumPy version 1.3.0.

Examples

>>> import numpy as np
>>> prob1 = np.log2(1e-50)
>>> prob2 = np.log2(2.5e-50)
>>> prob12 = np.logaddexp2(prob1, prob2)
>>> prob1, prob2, prob12
(-166.09640474436813, -164.77447664948076, -164.28904982231052)
>>> 2**prob12
3.4999999999999914e-50
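The example above uses NumPy. A rough JAX counterpart might look like the sketch below, which also shows broadcasting a scalar against an array and differentiating through the function. It assumes a working JAX installation with default settings (where floating-point arrays are typically float32, so results are less precise than the NumPy output above); the variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

# Base-2 log-probabilities; the scalar second argument is broadcast against the array.
log_p = jnp.array([-3.0, -1.0, 0.0])
out = jnp.logaddexp2(log_p, -1.0)  # elementwise log2(2**log_p + 2**-1)

# Derivative with respect to the first argument, evaluated at equal inputs.
# Analytically this is 2**x1 / (2**x1 + 2**x2), which equals 0.5 when x1 == x2.
grad_x1 = jax.grad(lambda a, b: jnp.logaddexp2(a, b))(0.0, 0.0)
```

The output has the broadcast shape (3,), and its middle entry is approximately 0 because 2**-1 + 2**-1 = 1.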