jax.random.wrap_key_data
- jax.random.wrap_key_data(key_bits_array, *, impl=None)
Wrap an array of key data bits into a PRNG key array.
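- Parameters:
  - key_bits_array (Array): a uint32 array with trailing shape corresponding to the key shape of the PRNG implementation specified by impl.
  - impl: optional, specifies a PRNG implementation, as in jax.random.key.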
- Returns:
- A PRNG key array, whose dtype is a subdtype of jax.dtypes.prng_key corresponding to impl, and whose shape equals the leading shape of key_bits_array.shape up to the key bit dimensions.
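
The shape rule in the return description can be illustrated with a minimal sketch. It assumes the "threefry2x32" implementation, whose keys are two uint32 words each; the names `bits`, `keys`, and `roundtrip` are illustrative and not part of the API.

```python
import jax
import jax.numpy as jnp

# Assumption: "threefry2x32" keys consist of two uint32 words,
# so the trailing dimension of the bits array must be 2.
bits = jnp.array([[0, 1], [0, 2], [0, 3]], dtype=jnp.uint32)  # shape (3, 2)

# Wrap the raw bits into a typed PRNG key array.
keys = jax.random.wrap_key_data(bits, impl="threefry2x32")

# The trailing key-bit dimension is consumed; the leading shape (3,) remains.
print(keys.shape)

# The resulting dtype is a subdtype of jax.dtypes.prng_key.
print(jnp.issubdtype(keys.dtype, jax.dtypes.prng_key))

# jax.random.key_data is the inverse operation and recovers the raw bits.
roundtrip = jax.random.key_data(keys)
```

Each element of `keys` can then be passed to samplers such as `jax.random.normal` like any key produced by `jax.random.key`.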