Pseudo Random Numbers in JAX

Authors: Matteo Hessel & Rosalia Schneider

In this section we focus on pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution.

PRNG-generated sequences are not truly random: they are fully determined by their initial value, typically referred to as the seed, and each step of random sampling is a deterministic function of some state that is carried over from one sample to the next.

Pseudo random number generation is an essential component of any machine learning or scientific computing framework. Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception.

To better understand the difference between the approaches taken by JAX and NumPy when it comes to random number generation we will discuss both approaches in this section.

Random numbers in NumPy

Pseudo random number generation is natively supported in NumPy by the numpy.random module.

In NumPy, pseudo random number generation is based on a global state.

This state can be set to a deterministic initial condition using np.random.seed(SEED).

import numpy as np
np.random.seed(0)

You can inspect the content of the state using the following command.

def print_truncated_random_state():
  """To avoid spamming the outputs, print only part of the state."""
  full_random_state = np.random.get_state()
  print(str(full_random_state)[:460], '...')

print_truncated_random_state()
('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ...

The state is updated by each call to a random function:

np.random.seed(0)

print_truncated_random_state()

_ = np.random.uniform()

print_truncated_random_state()
('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044,
       2481403966, 4042607538,  337614300, 3232553940, 1018809052,
       3202401494, 1775180719, 3192392114,  594215549,  184016991,
        829906058,  610491522, 3879932251, 3139825610,  297902587,
       4075895579, 2943625357, 3530655617, 1423771745, 2135928312,
       2891506774, 1066338622,  135451537,  933040465, 2759011858,
       2273819758, 3545703099, 2516396728, 127 ...
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,
       3904844661,  676747479, 2085143622, 1056793272, 3812477442,
       2168787041,  275552121, 2696932952, 3432054210, 1657102335,
       3518946594,  962584079, 1051271004, 3806145045, 1414436097,
       2032348584, 1661738718, 1116708477, 2562755208, 3176189976,
        696824676, 2399811678, 3992505346,  569184356, 2626558620,
        136797809, 4273176064,  296167901, 343 ...

NumPy allows you to sample either individual numbers or entire vectors of numbers in a single function call. For instance, you may sample a vector of 3 scalars from a uniform distribution by doing:

np.random.seed(0)
print(np.random.uniform(size=3))
[0.5488135  0.71518937 0.60276338]

NumPy provides a sequential equivalence guarantee, meaning that sampling N numbers in a row individually and sampling a vector of N numbers in one call produce the same pseudo-random sequence:

np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))

np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
individually: [0.5488135  0.71518937 0.60276338]
all at once:  [0.5488135  0.71518937 0.60276338]

Random numbers in JAX

JAX’s random number generation differs from NumPy’s in important ways. The reason is that NumPy’s PRNG design makes it hard to simultaneously guarantee a number of desirable properties for JAX, specifically that code must be:

  1. reproducible,

  2. parallelizable,

  3. vectorizable.

We will discuss why in the following. First, we will focus on the implications of a PRNG design based on a global state. Consider the code:

import numpy as np

np.random.seed(0)

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo(): return bar() + 2 * baz()

print(foo())
1.9791922366721637

The function foo computes a weighted sum of two scalars sampled from a uniform distribution.

The output of this code can only satisfy requirement #1 if we assume a specific order of execution for bar() and baz(), as native Python does.

This doesn’t seem to be a major issue in NumPy, as it is already enforced by Python, but it becomes an issue in JAX.

Making this code reproducible in JAX would require enforcing this specific order of execution. This would violate requirement #2: bar and baz don’t actually depend on each other, so JAX should be free to parallelize them when jitting.
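To make the order dependence concrete, here is a minimal NumPy-only sketch. It reuses the bar and baz functions from the example above and evaluates them in both orders from the same seed; the variable names bar_first and baz_first are introduced here purely for illustration.

```python
import numpy as np

def bar():
    return np.random.uniform()

def baz():
    return np.random.uniform()

# Order 1: bar() draws first, then baz() (what Python does for `bar() + 2 * baz()`).
np.random.seed(0)
bar_first = bar() + 2 * baz()

# Order 2: same seed, but baz() draws before bar().
np.random.seed(0)
baz_value = baz()
bar_value = bar()
baz_first = bar_value + 2 * baz_value

# Same seed, same functions, different results: the hidden global state
# makes the outcome depend on which call happens to run first.
print("bar first:", bar_first)
print("baz first:", baz_first)
```

Because both functions read from and advance the same hidden state, the result is only reproducible if the execution order is fixed.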

To avoid this issue, JAX does not use a global state. Instead, random functions explicitly consume the state, which is referred to as a key.

from jax import random

key = random.key(42)

print(key)
[ 0 42]

A single key is an array of scalar shape () and key element type.
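As a quick check of that description, the sketch below inspects a key's shape and dtype and then unwraps its underlying data. It assumes a JAX version that provides typed keys (random.key, random.key_data and jax.dtypes.prng_key) and the default PRNG implementation; the exact raw layout may differ under other implementations.

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.key(42)

# A single typed key has scalar shape.
print(key.shape)

# Its element type is a PRNG key dtype rather than a numeric dtype.
print(jnp.issubdtype(key.dtype, jax.dtypes.prng_key))

# key_data exposes the raw unsigned integers that back the key.
raw = random.key_data(key)
print(raw.shape, raw.dtype, raw)
```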

‘Random key’ is essentially just another word for ‘random seed’. However, instead of setting it once as in NumPy, any call of a random function in JAX requires a key to be specified. Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated:

print(random.normal(key))
print(random.normal(key))
-0.18471184
-0.18471184

Note: Feeding the same key to different random functions can result in correlated outputs, which is generally undesirable.

The rule of thumb is: never reuse keys (unless you want identical outputs).

In order to generate different and independent samples, you must split() the key yourself whenever you want to call a random function:

print("old key", key)
new_key, subkey = random.split(key)
del key  # The old key is discarded -- we must never use it again.
normal_sample = random.normal(subkey)
print(r"    \---SPLIT --> new key   ", new_key)
print(r"             \--> new subkey", subkey, "--> normal", normal_sample)
del subkey  # The subkey is also discarded after use.

# Note: you don't actually need to `del` keys -- that's just for emphasis.
# Not reusing the same values is enough.

key = new_key  # If we wanted to do this again, we would use new_key as the key.
old key [ 0 42]
    \---SPLIT --> new key    [2465931498 3679230171]
             \--> new subkey [255383827 267815257] --> normal 1.3694694

split() is a deterministic function that converts one key into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the new_key, and can safely use the unique extra key (called subkey) as input into a random function, and then discard it forever.

If you wanted to get another sample from the normal distribution, you would split key again, and so on. The crucial point is that you never use the same PRNG key twice. Since split() takes a key as its argument, we must throw away that old key when we split it.

It doesn’t matter which part of the output of split(key) we call key, and which we call subkey. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they’re consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.

Usually, the above example would be written concisely as

key, subkey = random.split(key)

which discards the old key automatically.
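As an illustration of this idiom, the following sketch draws several samples in a loop. Each iteration replaces key with a fresh key and consumes the subkey exactly once. The seed and the number of iterations are arbitrary choices made for this example.

```python
from jax import random

key = random.key(0)
samples = []

for _ in range(3):
    # Rebind `key` to a fresh key and use `subkey` for exactly one draw.
    key, subkey = random.split(key)
    samples.append(float(random.normal(subkey)))

# Each draw used a different subkey, so the samples are not forced to coincide.
print(samples)
```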

It’s worth noting that split() can create as many keys as you need, not just 2:

key, *forty_two_subkeys = random.split(key, num=43)
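With explicit keys, the earlier foo example can be rewritten so that its result no longer depends on execution order. The sketch below is one possible way to do this, not the only one. foo_reordered is a hypothetical variant added only to show that calling baz before bar gives the same value, and the same key is passed to both variants deliberately so the two results can be compared.

```python
from jax import random

def bar(key):
    return random.uniform(key)

def baz(key):
    return random.uniform(key)

def foo(key):
    # Each callee gets its own key, so neither depends on the other.
    bar_key, baz_key = random.split(key)
    return bar(bar_key) + 2 * baz(baz_key)

def foo_reordered(key):
    # Same computation, but baz runs before bar.
    bar_key, baz_key = random.split(key)
    baz_value = baz(baz_key)
    bar_value = bar(bar_key)
    return bar_value + 2 * baz_value

key = random.key(0)
# Reusing `key` here is intentional: we want identical inputs for the comparison.
result = foo(key)
result_reordered = foo_reordered(key)
print(result, result_reordered)
```

Since every random call receives its own key, nothing forces bar and baz to run in a particular order.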

Another difference between NumPy’s and JAX’s random modules relates to the sequential equivalence guarantee mentioned above.

Like NumPy’s, JAX’s random module allows sampling vectors of numbers. However, JAX does not provide a sequential equivalence guarantee, because doing so would interfere with vectorization on SIMD hardware (requirement #3 above).

In the example below, sampling 3 values from a normal distribution individually using three subkeys gives a different result from passing a single key and specifying shape=(3,):

key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)

key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
individually: [-0.04838839  0.10796146 -1.2226542 ]
all at once:  [ 0.18693541 -1.2806507  -1.5593133 ]

Note that contrary to our recommendation above, we use key directly as an input to random.normal() in the second example. This is because we won’t reuse it anywhere else, so we don’t violate the single-use principle.
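If you do want a batched call that reproduces the per-subkey samples, one option is to map the random function over an array of keys with jax.vmap. The sketch below assumes the default PRNG implementation, under which the vmapped call is expected to agree with sampling from each subkey one at a time.

```python
import jax
import numpy as np
from jax import random

key = random.key(42)
subkeys = random.split(key, 3)

# One call per subkey, as in the "individually" example above.
one_by_one = np.stack([random.normal(subkey) for subkey in subkeys])

# A single vectorized call that maps random.normal over the array of subkeys.
batched = jax.vmap(random.normal)(subkeys)

print("one by one:", one_by_one)
print("vmapped:   ", batched)
```

Here each subkey is still consumed exactly once per approach; the two approaches are shown side by side only for comparison.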