jax.random.rademacher

jax.random.rademacher(key, shape, dtype=<class 'int'>)

Sample from a Rademacher distribution.

The values are distributed according to the probability mass function:

\[f(k) = \frac{1}{2}(\delta(k - 1) + \delta(k + 1))\]

on the domain \(k \in \{-1, 1\}\), where \(\delta(x)\) is the Kronecker delta, equal to 1 if \(x = 0\) and 0 otherwise.
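As an informal check of the probability mass function above, the following sketch draws many samples and compares the empirical frequency of +1 with the expected value of 1/2. The seed and the sample count are arbitrary choices made for illustration.

```python
import jax
import jax.numpy as jnp

# Draw a large batch of Rademacher samples from a fixed (arbitrary) key.
key = jax.random.PRNGKey(0)
samples = jax.random.rademacher(key, (100_000,))

# Every sample should lie in the support {-1, 1}.
in_support = bool(jnp.all((samples == 1) | (samples == -1)))

# The empirical frequency of +1 should be close to f(1) = 1/2,
# which also puts the empirical mean close to 0.
frac_pos = float(jnp.mean(samples == 1))
mean = float(jnp.mean(samples))
print(in_support, frac_pos, mean)
```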

Parameters
  • key (Union[Array, PRNGKeyArray]) – A PRNG key.

  • shape (Sequence[int]) – The shape of the returned samples.

  • dtype (Union[Any, str, dtype, SupportsDType]) – The dtype of the returned samples (default: int).
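A minimal usage sketch of the three parameters. The seeds, shapes, and the choice of jnp.int8 are illustrative assumptions; with 64-bit mode disabled (the JAX default), dtype=int typically resolves to a 32-bit integer type.

```python
import jax
import jax.numpy as jnp

# A fixed seed for reproducibility; split it so each call gets independent randomness.
key = jax.random.PRNGKey(42)
key_default, key_int8 = jax.random.split(key)

# Default dtype: integer samples taking the values -1 and 1.
default_samples = jax.random.rademacher(key_default, (2, 3))

# Explicit dtype: the same kind of samples stored in a narrower integer type.
int8_samples = jax.random.rademacher(key_int8, (4,), dtype=jnp.int8)

print(default_samples.shape, default_samples.dtype)
print(int8_samples.shape, int8_samples.dtype)
```

Splitting the key before the two calls follows the usual JAX convention of never reusing a key for independent draws.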

Return type

Array

Returns

An array of samples with the given shape. Each element of the output is independently 1 or -1, each with a 50% chance.
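The 50/50 sampling rule can be sketched as a shift-and-scale of a fair Bernoulli draw, mapping \(b \in \{0, 1\}\) to \(2b - 1 \in \{-1, 1\}\). The helper rademacher_via_bernoulli below is hypothetical and written only for illustration; it is not claimed to match the library's internal implementation or to reproduce the exact output of jax.random.rademacher for the same key.

```python
import jax
import jax.numpy as jnp

def rademacher_via_bernoulli(key, shape, dtype=jnp.int32):
    """Hypothetical helper: map fair Bernoulli draws {False, True} to {-1, 1}."""
    b = jax.random.bernoulli(key, 0.5, shape)  # boolean array with P(True) = 1/2
    return 2 * b.astype(dtype) - 1             # False -> -1, True -> 1

key = jax.random.PRNGKey(7)
x = rademacher_via_bernoulli(key, (50_000,))
y = jax.random.rademacher(key, (50_000,))

# Both arrays take values only in {-1, 1}, with roughly equal frequencies.
print(float(jnp.mean(x == 1)), float(jnp.mean(y == 1)))
```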