jax.random.gamma
- jax.random.gamma(key, a, shape=None, dtype=<class 'float'>)
Sample Gamma random values with given shape and float dtype.
The values are distributed according to the probability density function:
\[f(x;a) \propto x^{a - 1} e^{-x}\]
on the domain \(0 \le x < \infty\), with \(a > 0\).
This is the standard gamma density, with a unit scale/rate parameter. Dividing the sample output by the rate is equivalent to sampling from gamma(a, rate), and multiplying the sample output by the scale is equivalent to sampling from gamma(a, scale).
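The rate/scale relationship above can be checked empirically. The following is a minimal sketch, assuming a shape parameter `a = 2.0`, a rate of `4.0` and a scale of `0.5` (all chosen only for illustration). It relies on the standard facts that a Gamma(a, rate) variable has mean `a / rate` and a Gamma(a, scale) variable has mean `a * scale`.

```python
import jax

key = jax.random.PRNGKey(0)
a = 2.0       # illustrative shape parameter
rate = 4.0    # illustrative rate parameter
scale = 0.5   # illustrative scale parameter

# Standard gamma samples (unit rate/scale); their mean should be close to a.
x = jax.random.gamma(key, a, shape=(100_000,))

# Dividing by the rate gives Gamma(a, rate) samples; mean close to a / rate = 0.5.
x_rate = x / rate

# Multiplying by the scale gives Gamma(a, scale) samples; mean close to a * scale = 1.0.
x_scale = x * scale

print(float(x.mean()), float(x_rate.mean()), float(x_scale.mean()))
```

Only the unit-scale sampler is needed; any rate or scale is applied afterwards as a cheap elementwise operation.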
- Parameters
  key (Union[Array, PRNGKeyArray]) – a PRNG key used as the random key.
  a (Union[Array, ndarray, bool_, number, bool, int, float, complex]) – a float or array of floats broadcast-compatible with shape representing the parameter of the distribution.
  shape (Optional[Sequence[int]]) – optional, a tuple of nonnegative integers specifying the result shape. Must be broadcast-compatible with a. The default (None) produces a result shape equal to a.shape.
  dtype (Union[Any, str, dtype, SupportsDType]) – optional, a float dtype for the returned values (default float64 if jax_enable_x64 is true, otherwise float32).
- Return type
  Array
- Returns
  A random array with the specified dtype and with shape given by shape if shape is not None, or else by a.shape.
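The interaction between `a`, `shape` and `dtype` described in the parameter and return fields can be illustrated as follows. This is a sketch with arbitrarily chosen values: a vector of three shape parameters, a `(4, 3)` output shape that broadcasts `a` across rows, and an explicit `float32` dtype.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)
# Use independent subkeys for independent draws.
k1, k2, k3 = jax.random.split(key, 3)

a = jnp.array([0.5, 1.0, 5.0])  # one shape parameter per column

# shape=None: the result shape equals a.shape, i.e. (3,).
s1 = jax.random.gamma(k1, a)

# shape=(4, 3): a (shape (3,)) is broadcast across 4 rows.
s2 = jax.random.gamma(k2, a, shape=(4, 3))

# A scalar a with an explicit output shape and dtype.
s3 = jax.random.gamma(k3, 2.0, shape=(2, 2), dtype=jnp.float32)

print(s1.shape, s2.shape, s3.dtype)
```

Each column of `s2` is drawn with its own entry of `a`, which is how a batch of distributions with different parameters can be sampled in a single call.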
See also
- loggamma(): sample gamma values in log-space, which can provide improved accuracy for small values of a.
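As a rough illustration of why the log-space variant helps, the sketch below draws samples with a very small shape parameter (`a = 1e-3`, chosen only for illustration) using both functions. For such small `a`, most of the probability mass sits extremely close to zero, so direct samples in the default float32 dtype may underflow to exactly `0.0`. The log-space samples remain finite, large negative numbers.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k_gamma, k_log = jax.random.split(key)

a = 1e-3  # illustrative, very small shape parameter

# Direct samples: many values may underflow to 0.0 in float32.
x = jax.random.gamma(k_gamma, a, shape=(10_000,))

# Log-space samples: log(x) is representable even when x itself is not.
log_x = jax.random.loggamma(k_log, a, shape=(10_000,))

print("zeros among gamma samples:", int((x == 0).sum()))
print("all loggamma samples finite:", bool(jnp.isfinite(log_x).all()))
print("mean of log samples:", float(log_x.mean()))
```

If the values are only needed in a downstream log-domain computation (for example, normalizing into Dirichlet weights via a log-sum-exp), keeping them in log-space avoids the underflow entirely.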