Pseudo Random Numbers in JAX
Authors: Matteo Hessel & Rosalia Schneider
In this section we focus on pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution.
PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the seed, and each step of random sampling is a deterministic function of some state that is carried over from one sample to the next.
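To make the seed/state idea concrete, here is a toy generator of our own. It is a linear congruential generator with the commonly cited Numerical Recipes constants, and it is *not* the algorithm NumPy or JAX actually uses. The class name `ToyLCG` is hypothetical. Each call is a deterministic update of the carried-over state, so two generators built from the same seed produce the same sequence.

```python
class ToyLCG:
    """Hypothetical toy PRNG for illustration only (not NumPy's or JAX's)."""

    A, C, M = 1664525, 1013904223, 2**32  # LCG constants (Numerical Recipes)

    def __init__(self, seed):
        self.state = seed % self.M  # the seed is simply the initial state

    def uniform(self):
        # One step: deterministic state update, then map the state to [0, 1).
        self.state = (self.A * self.state + self.C) % self.M
        return self.state / self.M


gen_a = ToyLCG(seed=0)
gen_b = ToyLCG(seed=0)
seq_a = [gen_a.uniform() for _ in range(3)]
seq_b = [gen_b.uniform() for _ in range(3)]
print(seq_a == seq_b)  # → True: same seed, same sequence
```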
Pseudo random number generation is an essential component of any machine learning or scientific computing framework. Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception.
To better understand the difference between the approaches taken by JAX and NumPy when it comes to random number generation, we will discuss both approaches in this section.
Random numbers in NumPy
Pseudo random number generation is natively supported in NumPy by the `numpy.random` module.

In NumPy, pseudo random number generation is based on a global state. This can be set to a deterministic initial condition using `np.random.seed(seed)`.
```python
import numpy as np

np.random.seed(0)
```
You can inspect the content of the state using the following command.
```python
def print_truncated_random_state():
  """To avoid spamming the outputs, print only part of the state."""
  full_random_state = np.random.get_state()
  print(str(full_random_state)[:460], '...')

print_truncated_random_state()
```

```
('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044, 2481403966, 4042607538, 337614300, 3232553940, 1018809052, 3202401494, 1775180719, 3192392114, 594215549, 184016991, 829906058, 610491522, 3879932251, 3139825610, 297902587, 4075895579, 2943625357, 3530655617, 1423771745, 2135928312, 2891506774, 1066338622, 135451537, 933040465, 2759011858, 2273819758, 3545703099, 2516396728, 127 ...
```
The state is updated by each call to a random function:

```python
np.random.seed(0)
print_truncated_random_state()

_ = np.random.uniform()
print_truncated_random_state()
```

```
('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044, 2481403966, 4042607538, 337614300, 3232553940, 1018809052, 3202401494, 1775180719, 3192392114, 594215549, 184016991, 829906058, 610491522, 3879932251, 3139825610, 297902587, 4075895579, 2943625357, 3530655617, 1423771745, 2135928312, 2891506774, 1066338622, 135451537, 933040465, 2759011858, 2273819758, 3545703099, 2516396728, 127 ...
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660, 3904844661, 676747479, 2085143622, 1056793272, 3812477442, 2168787041, 275552121, 2696932952, 3432054210, 1657102335, 3518946594, 962584079, 1051271004, 3806145045, 1414436097, 2032348584, 1661738718, 1116708477, 2562755208, 3176189976, 696824676, 2399811678, 3992505346, 569184356, 2626558620, 136797809, 4273176064, 296167901, 343 ...
```
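Because the next sample is a deterministic function of this global state, snapshotting the state with `np.random.get_state()` and restoring it with `np.random.set_state()` replays exactly the same draw. A minimal sketch using NumPy's legacy global-state API:

```python
import numpy as np

np.random.seed(0)
saved_state = np.random.get_state()  # snapshot of the global MT19937 state
first = np.random.uniform()          # this call advances the global state

np.random.set_state(saved_state)     # rewind the global state to the snapshot
replayed = np.random.uniform()

print(first == replayed)  # → True: same state, same next sample
```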
NumPy allows you to sample either individual numbers or entire vectors of numbers in a single function call. For instance, after `np.random.seed(0)`, you may sample a vector of 3 scalars from a uniform distribution with `np.random.uniform(size=3)`, which gives:

```
[0.5488135  0.71518937 0.60276338]
```
NumPy provides a sequential equivalence guarantee, meaning that sampling N numbers in a row individually or sampling a vector of N numbers results in the same pseudo-random sequence:

```python
np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))

np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
```

```
individually: [0.5488135  0.71518937 0.60276338]
all at once:  [0.5488135  0.71518937 0.60276338]
```
Random numbers in JAX
JAX’s random number generation differs from NumPy’s in important ways. The reason is that NumPy’s PRNG design makes it hard to simultaneously guarantee a number of desirable properties for JAX, specifically that code must be:

1. reproducible,
2. parallelizable,
3. vectorizable.

We will discuss why in the following. First, we will focus on the implications of a PRNG design based on a global state. Consider the code:
```python
import numpy as np

np.random.seed(0)

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

def foo(): return bar() + 2 * baz()

print(foo())
```
`foo` sums two scalars sampled from a uniform distribution (weighting the second by 2).

The output of this code can only satisfy requirement #1 if we assume a specific order of execution for `bar()` and `baz()`, as native Python does. This doesn’t seem to be a major issue in NumPy, as the order is already enforced by Python, but it becomes an issue in JAX.
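To see why the order matters, the sketch below (our own illustration; `bar` and `baz` are re-declared so it runs standalone) evaluates the same expression under the order Python uses and under a hypothetical swapped schedule, such as one a parallelizing runtime might choose. Starting from the same seed, the two schedules hand the two draws to different functions and therefore produce different results.

```python
import numpy as np

def bar(): return np.random.uniform()
def baz(): return np.random.uniform()

# Python's order: bar() is evaluated before baz().
np.random.seed(0)
b1 = bar()
z1 = baz()
in_order = b1 + 2 * z1

# Hypothetical alternative schedule: baz() happens to run first.
np.random.seed(0)
z2 = baz()
b2 = bar()
swapped = b2 + 2 * z2

print(in_order, swapped)  # same seed, different results
```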
Making this code reproducible in JAX would require enforcing this specific order of execution. This would violate requirement #2, as JAX should be able to parallelize `bar` and `baz` when jitting, since these functions don’t actually depend on each other.
To avoid this issue, JAX does not use a global state. Instead, random functions explicitly consume the state, which is referred to as a `key`.

```python
from jax import random

key = random.PRNGKey(42)
print(key)
```

```
[ 0 42]
```

A key is just an array of shape `(2,)`.
‘Random key’ is essentially just another word for ‘random seed’. However, instead of setting it once as in NumPy, every call to a random function in JAX requires a key to be specified. Random functions consume the key but do not modify it, so feeding the same key to a random function will always result in the same sample being generated.
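For instance, calling `random.normal` twice with the same key returns the same value both times, because the key is only read, never updated. A minimal sketch (the seed 42 is arbitrary):

```python
from jax import random

key = random.PRNGKey(42)

# The key is not modified by these calls, so both draws see the same input.
sample_1 = random.normal(key)
sample_2 = random.normal(key)

print(sample_1, sample_2)
print(bool(sample_1 == sample_2))  # → True
```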
Note: Feeding the same key to different random functions can result in correlated outputs, which is generally undesirable.
The rule of thumb is: never reuse keys (unless you want identical outputs).
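As an empirical illustration of the note above, the sketch below deliberately feeds one key to two different samplers. Our understanding is that in current JAX versions `random.normal` is built from the same uniform bits that `random.uniform` draws from a given key, so the two outputs come out strongly correlated. This is an observation about JAX’s internals, not a documented guarantee, and the exact correlation value may vary.

```python
import numpy as np
from jax import random

key = random.PRNGKey(0)

# Anti-pattern, shown on purpose: the same key drives two different samplers.
u = np.asarray(random.uniform(key, shape=(1000,)))
z = np.asarray(random.normal(key, shape=(1000,)))

# Independently drawn samples would have a correlation close to 0.
corr = np.corrcoef(u, z)[0, 1]
print(f"correlation between uniform and normal samples: {corr:.3f}")
```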
In order to generate different and independent samples, you must `split()` the key yourself whenever you want to call a random function:
```python
print("old key", key)
new_key, subkey = random.split(key)
del key  # The old key is discarded -- we must never use it again.
normal_sample = random.normal(subkey)
print(r"    \---SPLIT --> new key   ", new_key)
print(r"             \--> new subkey", subkey, "--> normal", normal_sample)
del subkey  # The subkey is also discarded after use.

# Note: you don't actually need to `del` keys -- that's just for emphasis.
# Not reusing the same values is enough.

key = new_key  # If we wanted to do this again, we would use new_key as the key.
```

```
old key [ 0 42]
    \---SPLIT --> new key    [2465931498 3679230171]
             \--> new subkey [255383827 267815257] --> normal 1.3694694
```
`split()` is a deterministic function that converts one key into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.
If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNGKey twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.
It doesn’t matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they’re consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.
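Putting the key/subkey convention to work, drawing several samples one at a time might look like the following minimal sketch (seed 0 and three iterations are arbitrary choices). Each iteration splits the carried key, consumes the subkey immediately, and keeps the other half for the next iteration.

```python
from jax import random

key = random.PRNGKey(0)
samples = []
for _ in range(3):
    # Split off a fresh subkey for this draw; keep `key` for later draws.
    key, subkey = random.split(key)
    samples.append(float(random.normal(subkey)))  # subkey is used exactly once

print(samples)
```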
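Usually, the above example would be written concisely as

```python
key, subkey = random.split(key)
```

which discards the old key automatically: the name `key` is rebound to the new key, so the old one can no longer be used by accident.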
It’s worth noting that `split()` can create as many keys as you need, not just 2:

```python
key, *forty_two_subkeys = random.split(key, num=43)
```
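Many subkeys are typically handed out to independent consumers. As a sketch (using `jax.vmap`, which this section does not otherwise discuss), the 42 subkeys can be stacked into a single array and mapped over in one vectorized call, producing one sample per subkey:

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
key, *forty_two_subkeys = random.split(key, num=43)

# Stack the 42 subkeys into one array whose leading axis indexes the keys.
stacked_subkeys = jnp.stack(forty_two_subkeys)

# jax.vmap maps random.uniform over that leading axis: one draw per subkey.
samples = jax.vmap(random.uniform)(stacked_subkeys)
print(samples.shape)  # → (42,)
```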
Another difference between NumPy’s and JAX’s random modules relates to the sequential equivalence guarantee mentioned above.
As in NumPy, JAX’s random module also allows sampling of vectors of numbers. However, JAX does not provide a sequential equivalence guarantee, because doing so would interfere with the vectorization on SIMD hardware (requirement #3 above).
In the example below, sampling 3 values from a normal distribution individually, using three subkeys, gives a different result from using a single key and specifying `shape=(3,)`:
```python
key = random.PRNGKey(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)

key = random.PRNGKey(42)
print("all at once: ", random.normal(key, shape=(3,)))
```

```
individually: [-0.04838839  0.10796146 -1.2226542 ]
all at once:  [ 0.18693541 -1.2806507  -1.5593133 ]
```
Note that, contrary to our recommendation above, we use `key` directly as an input to `random.normal()` in the second example. This is because we won’t reuse it anywhere else, so we don’t violate the single-use principle.
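Returning to the `foo` example from the NumPy discussion, here is a minimal sketch (our own rewrite, not taken from the text above) of how it might look with explicit keys. `foo` splits its key and hands one subkey to each of `bar` and `baz`, so the two calls share no state. Their results no longer depend on evaluation order, which leaves JAX free to schedule them however it likes under `jax.jit`.

```python
import jax
from jax import random

def bar(key):
    return random.uniform(key)

def baz(key):
    return random.uniform(key)

@jax.jit
def foo(key):
    # Each callee receives its own subkey, so neither depends on the other.
    bar_key, baz_key = random.split(key)
    return bar(bar_key) + 2 * baz(baz_key)

key = random.PRNGKey(0)
print(foo(key))
print(bool(foo(key) == foo(key)))  # → True: same key, same result
```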