jax.lax.rng_bit_generator

jax.lax.rng_bit_generator(key, shape, dtype=<class 'numpy.uint32'>, algorithm=jaxlib.xla_extension.ops.RandomAlgorithm.RNG_DEFAULT)

Stateless PRNG bit generator. This API is experimental, and its use is discouraged.

Returns uniformly distributed random bits with the specified shape and dtype (which must be an integer type), using either the platform-specific default algorithm or the one specified.

It provides direct, low-level access to the RngBitGenerator primitive exposed by XLA (https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator).
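A minimal sketch of a direct call, using the default algorithm and the default uint32 dtype. It rests on two assumptions that are not stated on this page and may differ between JAX versions: that the key can be passed as a length-4 uint32 array, and that the function returns a (new_key, bits) pair. The key values below are arbitrary.

```python
import jax.numpy as jnp
from jax import lax

# Assumption: a length-4 uint32 array is accepted as the raw key.
key = jnp.array([0, 1, 2, 3], dtype=jnp.uint32)

# Assumption: the call returns an updated key together with the random bits.
new_key, bits = lax.rng_bit_generator(key, (2, 3))

# The generator is stateless, so reusing the same key gives the same bits.
_, bits_again = lax.rng_bit_generator(key, (2, 3))

print(bits.shape, bits.dtype)
print(bool(jnp.array_equal(bits, bits_again)))
```

Because no hidden state is updated, fresh bits require passing a different key, for example the returned new_key.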

Most users should use jax.random instead, which offers a stable and more user-friendly API.
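For comparison, a sketch of drawing raw random bits through jax.random. The seed value is arbitrary, and jax.random.bits is assumed to be available in the installed JAX version; it is only one of several sampling functions in that module.

```python
import jax
import jax.numpy as jnp

# Build a key from an integer seed, then split it so each use gets its own subkey.
root_key = jax.random.PRNGKey(0)
root_key, subkey = jax.random.split(root_key)

# Draw uniformly distributed uint32 bits with the requested shape.
random_bits = jax.random.bits(subkey, shape=(2, 3), dtype=jnp.uint32)

print(random_bits.shape, random_bits.dtype)
```

Here key handling goes through jax.random.PRNGKey and jax.random.split rather than a hand-built array, which is the main practical difference from calling the primitive directly.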