jax.random.key

jax.random.key(seed, *, impl=None)

Create a pseudo-random number generator (PRNG) key given an integer seed.

The result is a scalar array containing a key whose dtype indicates the PRNG implementation: the one named by the optional impl argument or, if impl is not given, the default set by the jax_default_prng_impl config flag.
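The sketch below shows what "a scalar array whose dtype indicates the implementation" looks like in practice. It assumes a recent JAX release with typed key arrays. The helpers jax.random.key_impl and jax.dtypes.prng_key are not mentioned on this page and are used here only for inspection.

```python
import jax
import jax.numpy as jnp

# Create a typed key from an integer seed. With no impl argument, the
# implementation comes from the jax_default_prng_impl config flag
# (threefry2x32 unless that flag has been changed).
k = jax.random.key(0)

print(k.shape)  # () : the key array is a scalar
print(k.dtype)  # the dtype names the implementation, e.g. key<fry> for threefry2x32

# Typed key dtypes are sub-dtypes of jax.dtypes.prng_key, which gives a
# way to tell key arrays apart from ordinary numeric arrays.
assert jnp.issubdtype(k.dtype, jax.dtypes.prng_key)

# jax.random.key_impl reports which PRNG implementation the key carries.
print(jax.random.key_impl(k))
```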

Parameters:
  • seed (int | ArrayLike) – a 64- or 32-bit integer used as the value of the key.

  • impl (PRNGSpecDesc | None) – optional string specifying the PRNG implementation (e.g. 'threefry2x32').
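As a rough illustration of both parameters, the following sketch passes impl explicitly and supplies the seed as an array rather than a Python int. It assumes the implementation names 'threefry2x32' and 'rbg' (the values accepted by jax_default_prng_impl) and uses jax.random.key_data, which is not described on this page, to look at the raw uint32 words behind each key.

```python
import jax
import jax.numpy as jnp

# Request an implementation by name instead of relying on the
# jax_default_prng_impl config flag.
k_fry = jax.random.key(42, impl="threefry2x32")
k_rbg = jax.random.key(42, impl="rbg")

# key_data exposes the uint32 words backing each key; the trailing
# shape depends on the implementation.
print(jax.random.key_data(k_fry).shape)  # (2,)
print(jax.random.key_data(k_rbg).shape)  # (4,)

# Because seed may be an array, jax.random.key can be mapped over a
# vector of seeds, producing one scalar key per seed.
keys = jax.vmap(jax.random.key)(jnp.arange(3))
print(keys.shape)  # (3,)
```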

Returns:

A scalar PRNG key array, consumable by random functions as well as split and fold_in.
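A minimal usage sketch of the returned key follows. It splits the key, derives a key with fold_in, and feeds keys to two sampling functions. jax.random.uniform and jax.random.normal are chosen only as typical examples of the "random functions" mentioned above, and the integer 7 passed to fold_in is an arbitrary illustrative value.

```python
import jax

key = jax.random.key(0)

# split returns a key array of new keys (two by default), which can be
# unpacked into a key to carry forward and a subkey to consume.
key, subkey = jax.random.split(key)

# fold_in deterministically derives a new key from an existing key and
# an integer, such as a step counter or device index.
step_key = jax.random.fold_in(key, 7)

# Sampling functions take a key as their first argument.
x = jax.random.uniform(subkey, (3,))
y = jax.random.normal(step_key, (2, 2))

# Keys are deterministic: rebuilding the same key from the same seed
# reproduces the same samples.
x_again = jax.random.uniform(jax.random.split(jax.random.key(0))[1], (3,))
assert (x == x_again).all()
```

Keeping the original key out of sampling calls and consuming only freshly split subkeys avoids reusing the same random stream twice.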

Return type:

KeyArray