jax.scipy.stats.multinomial.logpmf

jax.scipy.stats.multinomial.logpmf(x, n, p)

Log of the Multinomial probability mass function.
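For reference, the quantity being computed can be written out using the standard textbook form of the multinomial distribution. This is a sketch of the definition, not a description of the library's internals; the symbol k (the number of categories, i.e. the length of the last axis of x) is notation introduced here.

```latex
\log f(x; n, p) = \log n! - \sum_{i=1}^{k} \log x_i! + \sum_{i=1}^{k} x_i \log p_i,
\qquad \text{where } \sum_{i=1}^{k} x_i = n .
```

When the counts in x do not sum to n, the probability mass is 0, so its logarithm is negative infinity.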

LAX-backend implementation of scipy.stats.multinomial.logpmf().

Original docstring below.

Parameters:
  • x (array_like) – Quantiles, with the last axis of x denoting the components.

  • n (int) – Number of trials.

  • p (array_like) – Probability of a trial falling into each category; should sum to 1.

Returns:

logpmf – Log of the probability mass function evaluated at x

Return type:

ndarray or scalar
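A minimal usage sketch follows. The counts and probabilities are made-up illustrative values, and the sketch assumes a JAX version that ships jax.scipy.stats.multinomial and that x is passed as an integer array. The result is cross-checked against the closed-form multinomial log-probability computed with the standard library.

```python
import math

import jax.numpy as jnp
from jax.scipy.stats import multinomial

# Illustrative data (not from the original docs): 4 trials, 3 categories.
x = jnp.array([2, 1, 1])        # integer counts per category; they sum to n
n = 4                           # number of trials
p = jnp.array([0.5, 0.3, 0.2])  # category probabilities; they sum to 1

log_prob = multinomial.logpmf(x, n, p)

# Closed-form check: log n! - sum(log x_i!) + sum(x_i * log p_i)
counts = [2, 1, 1]
probs = [0.5, 0.3, 0.2]
expected = (
    math.lgamma(n + 1)
    - sum(math.lgamma(c + 1) for c in counts)
    + sum(c * math.log(q) for c, q in zip(counts, probs))
)
print(float(log_prob), expected)  # both approximately -1.7148, i.e. log(0.18)

# Counts that do not sum to n have zero probability, so the log is -inf.
log_prob_bad = multinomial.logpmf(jnp.array([1, 1, 1]), n, p)
print(float(log_prob_bad))
```

Exponentiating the first result gives a probability of about 0.18, which matches 4!/(2!·1!·1!) · 0.5² · 0.3 · 0.2.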