jax.random.binomial

jax.random.binomial(key, n, p, shape=None, dtype=<class 'float'>)

Sample binomial random values with the given shape and float dtype.

The values are returned according to the probability mass function:

\[f(k;n,p) = \binom{n}{k}p^k(1-p)^{n-k}\]

for \(k = 0, 1, \ldots, n\), on the domain \(0 < p < 1\), where \(n\) is a nonnegative integer representing the number of trials and \(p\) is a float representing the probability of success of a single trial.
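One way to sanity-check the mass function above is to evaluate it directly for a small \(n\). The sketch below uses only Python's standard library (`math.comb`); the helper `binomial_pmf` and the values `n = 10`, `p = 0.3` are hypothetical choices for illustration and are not part of JAX. It confirms that the probabilities sum to 1 and that the mean equals \(np\).

```python
import math


def binomial_pmf(k: int, n: int, p: float) -> float:
    """Hypothetical helper: evaluate f(k; n, p) = C(n, k) p^k (1 - p)^(n - k)."""
    return math.comb(n, k) * p**k * (1.0 - p) ** (n - k)


n, p = 10, 0.3  # illustrative values only

# Probability of each possible count k = 0, 1, ..., n.
pmf = [binomial_pmf(k, n, p) for k in range(n + 1)]

total = sum(pmf)                              # should be 1 up to rounding
mean = sum(k * f for k, f in enumerate(pmf))  # should equal n * p
print(f"sum={total:.6f} mean={mean:.6f}")
```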

Parameters:
  • key (Array) – a PRNG key used as the random key.

  • n (Array | ndarray | bool_ | number | bool | int | float | complex) – a float or array of floats, broadcast-compatible with shape, representing the number of trials.

  • p (Array | ndarray | bool_ | number | bool | int | float | complex) – a float or array of floats, broadcast-compatible with shape, representing the probability of success of an individual trial.

  • shape (Sequence[int] | None) – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with n and p. The default (None) produces a result shape equal to np.broadcast(n, p).shape.

  • dtype (str | type[Any] | dtype | SupportsDType) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
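The call below is a minimal sketch of how the parameters fit together. It assumes a JAX version that provides `jax.random.binomial` and uses the long-standing `jax.random.PRNGKey` constructor for the key. The values of `n` and `p` are illustrative only: `n` is a length-3 array and `p` is a scalar, and both broadcast against the requested `shape=(4, 3)`. `dtype` is passed explicitly so the result is float32 whether or not `jax_enable_x64` is set.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# n has shape (3,) and p is a scalar; both broadcast against shape=(4, 3).
n = jnp.array([5.0, 20.0, 100.0])
p = 0.25

samples = jax.random.binomial(key, n, p, shape=(4, 3), dtype=jnp.float32)

print(samples.shape, samples.dtype)
print(samples)
```

Although the returned dtype is floating point, each entry should hold a whole-number count between 0 and the corresponding value of `n`.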

Returns:

A random array with the specified dtype and with shape equal to shape if it is given, otherwise np.broadcast(n, p).shape.

Return type:

Array
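When `shape` is omitted, the result takes the broadcast shape of `n` and `p`. The sketch below checks that against `np.broadcast(n, p).shape`. It again assumes `jax.random.PRNGKey` for the key, and the arrays `n` (shape `(3, 1)`) and `p` (shape `(2,)`) are illustrative values chosen only for this example.

```python
import jax
import jax.numpy as jnp
import numpy as np

key = jax.random.PRNGKey(42)

n = jnp.array([[10.0], [30.0], [60.0]])  # shape (3, 1)
p = jnp.array([0.1, 0.9])                # shape (2,)

# shape=None: the output shape comes from broadcasting n against p.
samples = jax.random.binomial(key, n, p)

expected = np.broadcast(np.asarray(n), np.asarray(p)).shape  # (3, 2)
print(samples.shape, expected)
```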