# Pseudo Random Numbers in JAX


*Authors: Matteo Hessel & Rosalia Schneider*

In this section we focus on pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution.

PRNG-generated sequences are not truly random because they are fully determined by their initial value, which is typically referred to as the `seed`, and each step of random sampling is a deterministic function of some `state` that is carried over from one sample to the next.
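To make the roles of `seed` and `state` concrete, here is a minimal sketch of a linear congruential generator, one of the simplest PRNG schemes. It is an illustration only: it is not the algorithm used by NumPy or JAX, and the helper names `lcg_step` and `lcg_sequence`, as well as the constants (a commonly cited set of 32-bit parameters), are chosen for this example.

```python
# Toy PRNG for illustration: a linear congruential generator (LCG).
# NOT the algorithm used by NumPy or JAX; all names here are hypothetical.

MODULUS = 2**32
MULTIPLIER = 1664525
INCREMENT = 1013904223

def lcg_step(state):
    """Advance the state deterministically; return (new_state, sample in [0, 1))."""
    new_state = (MULTIPLIER * state + INCREMENT) % MODULUS
    return new_state, new_state / MODULUS

def lcg_sequence(seed, n):
    """Generate n samples, carrying the state from one sample to the next."""
    state = seed
    samples = []
    for _ in range(n):
        state, sample = lcg_step(state)
        samples.append(sample)
    return samples

# The same seed always reproduces the same sequence.
print(lcg_sequence(seed=0, n=3))
print(lcg_sequence(seed=0, n=3))
```

The whole sequence is a function of the seed alone, which is exactly why such sequences are called pseudo random.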

Pseudo random number generation is an essential component of any machine learning or scientific computing framework. Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception.

To better understand the difference between the approaches taken by JAX and NumPy when it comes to random number generation we will discuss both approaches in this section.

## Random numbers in NumPy

Pseudo random number generation is natively supported in NumPy by the `numpy.random` module.

In NumPy, pseudo random number generation is based on a global `state`.

This can be set to a deterministic initial condition using `np.random.seed(SEED)`.

```
import numpy as np
np.random.seed(0)
```

You can inspect the content of the state using the following command.

```
def print_truncated_random_state():
    """To avoid spamming the outputs, print only part of the state."""
    full_random_state = np.random.get_state()
    print(str(full_random_state)[:460], '...')

print_truncated_random_state()
```

```
('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,
2481403966, 4042607538, 337614300, 3232553940, 1018809052,
3202401494, 1775180719, 3192392114, 594215549, 184016991,
829906058, 610491522, 3879932251, 3139825610, 297902587,
4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
2891506774, 1066338622, 135451537, 933040465, 2759011858,
2273819758, 3545703099, 2516396728, 127 ...
```

The `state` is updated by each call to a random function:

```
np.random.seed(0)
print_truncated_random_state()
_ = np.random.uniform()
print_truncated_random_state()
```

```
('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,
2481403966, 4042607538, 337614300, 3232553940, 1018809052,
3202401494, 1775180719, 3192392114, 594215549, 184016991,
829906058, 610491522, 3879932251, 3139825610, 297902587,
4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
2891506774, 1066338622, 135451537, 933040465, 2759011858,
2273819758, 3545703099, 2516396728, 127 ...
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
3904844661, 676747479, 2085143622, 1056793272, 3812477442,
2168787041, 275552121, 2696932952, 3432054210, 1657102335,
3518946594, 962584079, 1051271004, 3806145045, 1414436097,
2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
696824676, 2399811678, 3992505346, 569184356, 2626558620,
136797809, 4273176064, 296167901, 343 ...
```

NumPy allows you to sample either individual numbers or entire vectors of numbers in a single function call. For instance, you may sample a vector of 3 scalars from a uniform distribution by doing:

```
np.random.seed(0)
print(np.random.uniform(size=3))
```

```
[0.5488135 0.71518937 0.60276338]
```

NumPy provides a *sequential equivalence guarantee*, meaning that sampling N numbers in a row individually or sampling a vector of N numbers results in the same pseudo-random sequence:

```
np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))
np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
```

```
individually: [0.5488135 0.71518937 0.60276338]
all at once: [0.5488135 0.71518937 0.60276338]
```

## Random numbers in JAX

JAX's random number generation differs from NumPy's in important ways. The reason is that NumPy's PRNG design makes it hard to simultaneously guarantee a number of desirable properties for JAX, specifically that code must be:

1. reproducible,
2. parallelizable,
3. vectorisable.

We will discuss why in the following. First, we will focus on the implications of a PRNG design based on a global state. Consider the code:

```
import numpy as np
np.random.seed(0)
def bar(): return np.random.uniform()
def baz(): return np.random.uniform()
def foo(): return bar() + 2 * baz()
print(foo())
```

```
1.9791922366721637
```

The function `foo` sums two scalars sampled from a uniform distribution.

The output of this code can only satisfy requirement #1 if we assume a specific order of execution for `bar()` and `baz()`, as native Python does.

This doesn't seem to be a major issue in NumPy, as it is already enforced by Python, but it becomes an issue in JAX.

Making this code reproducible in JAX would require enforcing this specific order of execution. This would violate requirement #2: since these functions don't actually depend on each other, JAX should be able to parallelize `bar` and `baz` when jitting.
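The dependence on execution order can be checked directly in NumPy. The sketch below reuses `bar` and `baz` from the example above and evaluates them in both orders from the same seed; the variable names `bar_first` and `baz_first` are introduced here purely for illustration.

```python
import numpy as np

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

# Order 1: bar() draws from the global state first.
np.random.seed(0)
bar_value = bar()
baz_value = baz()
bar_first = bar_value + 2 * baz_value

# Order 2: baz() draws from the global state first.
np.random.seed(0)
baz_value = baz()
bar_value = bar()
baz_first = bar_value + 2 * baz_value

print("bar first:", bar_first)
print("baz first:", baz_first)
```

Both runs start from the same seed, yet the two results differ, because each function's sample depends on how many draws were made from the global state before it was called.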

To avoid this issue, JAX does not use a global state. Instead, random functions explicitly consume the state, which is referred to as a `key`.

```
from jax import random
key = random.PRNGKey(42)
print(key)
```

```
[ 0 42]
```

A key is just an array of shape `(2,)`.
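As a quick check, the shape and dtype of the key can be inspected like those of any other array. This sketch assumes JAX's default PRNG implementation; other configurations may represent keys differently.

```python
from jax import random

key = random.PRNGKey(42)

# Under the default PRNG implementation (an assumption of this sketch),
# the key is a plain array of two unsigned 32-bit integers.
print(key.shape)
print(key.dtype)
```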

'Random key' is essentially just another word for 'random seed'. However, instead of setting it once as in NumPy, any call of a random function in JAX requires a key to be specified. Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated:

```
print(random.normal(key))
print(random.normal(key))
```

```
-0.18471184
-0.18471184
```

**Note:** Feeding the same key to different random functions can result in correlated outputs, which is generally undesirable.

**The rule of thumb is: never reuse keys (unless you want identical outputs).**

In order to generate different and independent samples, you must `split()` the key *yourself* whenever you want to call a random function:

```
print("old key", key)
new_key, subkey = random.split(key)
del key # The old key is discarded -- we must never use it again.
normal_sample = random.normal(subkey)
print(r" \---SPLIT --> new key ", new_key)
print(r" \--> new subkey", subkey, "--> normal", normal_sample)
del subkey # The subkey is also discarded after use.
# Note: you don't actually need to `del` keys -- that's just for emphasis.
# Not reusing the same values is enough.
key = new_key # If we wanted to do this again, we would use new_key as the key.
```

```
old key [ 0 42]
\---SPLIT --> new key [2465931498 3679230171]
\--> new subkey [255383827 267815257] --> normal 1.3694694
```

`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.

If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.

It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.

Usually, the above example would be written concisely as

```
key, subkey = random.split(key)
```

which discards the old key automatically.
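A minimal sketch of how this idiom is typically used to draw several independent samples in a loop (the seed value and the `samples` list are arbitrary choices for this illustration):

```python
from jax import random

key = random.PRNGKey(0)
samples = []
for _ in range(3):
    # Rebinding `key` drops the old key; each `subkey` is used exactly once.
    key, subkey = random.split(key)
    samples.append(float(random.normal(subkey)))

print(samples)
```

Each iteration consumes a fresh subkey, so no key is ever passed to a random function twice.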

It's worth noting that `split()` can create as many keys as you need, not just 2:

```
key, *forty_two_subkeys = random.split(key, num=43)
```
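When many subkeys are created at once, they come back stacked in a single array, which makes it convenient to map a random function over them. The sketch below assumes the default key representation and uses `jax.vmap`; the variable names are illustrative.

```python
import jax
from jax import random

key = random.PRNGKey(0)
keys = random.split(key, num=5)
# Keep the first output as the carried key; use the rest as subkeys.
key, subkeys = keys[0], keys[1:]
print(subkeys.shape)

# Draw one normal sample per subkey by mapping over the leading axis.
samples = jax.vmap(random.normal)(subkeys)
print(samples.shape)
```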

Another difference between NumPy's and JAX's random modules relates to the sequential equivalence guarantee mentioned above.

As in NumPy, JAX's random module also allows sampling of vectors of numbers. However, JAX does not provide a sequential equivalence guarantee, because doing so would interfere with the vectorization on SIMD hardware (requirement #3 above).

In the example below, sampling 3 values out of a normal distribution individually using three subkeys gives a different result from passing a single key and specifying `shape=(3,)`:

```
key = random.PRNGKey(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)
key = random.PRNGKey(42)
print("all at once: ", random.normal(key, shape=(3,)))
```

```
individually: [-0.04838839 0.10796146 -1.2226542 ]
all at once: [ 0.18693541 -1.2806507 -1.5593133 ]
```

Note that contrary to our recommendation above, we use `key` directly as an input to `random.normal()` in the second example. This is because we won't reuse it anywhere else, so we don't violate the single-use principle.
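Returning to the `foo` example from the start of this section, here is one way it might be written with explicit keys. This is a sketch rather than a prescribed pattern: `foo_reordered` is a hypothetical variant added only to show that the evaluation order no longer matters, and the same key is deliberately passed to both versions so that their outputs can be compared.

```python
from jax import random

def bar(key):
    return random.uniform(key)

def baz(key):
    return random.uniform(key)

def foo(key):
    # Each callee receives its own key, so neither depends on the other.
    bar_key, baz_key = random.split(key)
    return bar(bar_key) + 2 * baz(baz_key)

def foo_reordered(key):
    # Same computation, but baz is evaluated before bar.
    bar_key, baz_key = random.split(key)
    baz_value = baz(baz_key)
    bar_value = bar(bar_key)
    return bar_value + 2 * baz_value

key = random.PRNGKey(0)
# The key is reused here on purpose, to show that both orders agree.
print(foo(key), foo_reordered(key))
```

Because each sample is determined by the key it is given and not by a shared global state, the result is the same whichever of `bar` and `baz` runs first.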