jax.random.wrap_key_data

jax.random.wrap_key_data(key_bits_array, *, impl=None)

Wrap an array of key data bits into a PRNG key array.

Parameters:
  • key_bits_array (Array) – a uint32 array with trailing shape corresponding to the key shape of the PRNG implementation specified by impl.

  • impl (str | PRNGSpec | PRNGImpl | None) – optional; specifies a PRNG implementation, as in jax.random.key.
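The trailing shape that key_bits_array must have depends on impl. The sketch below assumes the built-in "threefry2x32" implementation (two uint32 words per key) and "rbg" implementation (four uint32 words per key); the zero and arange bit patterns are arbitrary illustrative values, not meaningful seeds.

```python
import jax
import jax.numpy as jnp

# threefry2x32 keys are two uint32 words each, so a (3, 2) array
# wraps into a batch of three keys.
threefry_bits = jnp.arange(6, dtype=jnp.uint32).reshape(3, 2)
threefry_keys = jax.random.wrap_key_data(threefry_bits, impl="threefry2x32")
print(threefry_keys.shape)  # (3,)

# rbg keys are four uint32 words each, so a (4,) array wraps into
# a single (scalar-shaped) key.
rbg_bits = jnp.zeros((4,), dtype=jnp.uint32)
rbg_key = jax.random.wrap_key_data(rbg_bits, impl="rbg")
print(rbg_key.shape)  # ()
```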

Returns:

A PRNG key array whose dtype is a subdtype of jax.dtypes.prng_key corresponding to impl, and whose shape equals the leading part of key_bits_array.shape, that is, key_bits_array.shape with the trailing key-bit dimensions removed.
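A common use is a round trip through raw bits, for example to store a key and restore it later. The sketch below pairs wrap_key_data with jax.random.key_data and passes impl="threefry2x32" explicitly, assuming the bits were produced by that same implementation; wrapping bits under a different impl than the one that produced them would not reproduce the original stream.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(42, impl="threefry2x32")

# Unwrap the typed key into its raw uint32 bits.
raw = jax.random.key_data(key)
print(raw.shape, raw.dtype)  # (2,) uint32

# Wrap the bits back into a typed key: the trailing (2,) key-bit
# dimension is absorbed, leaving a scalar-shaped key array.
restored = jax.random.wrap_key_data(raw, impl="threefry2x32")
print(restored.shape)  # ()
print(jax.dtypes.issubdtype(restored.dtype, jax.dtypes.prng_key))  # True

# Identical key bits drive an identical random stream.
a = jax.random.uniform(key, (3,))
b = jax.random.uniform(restored, (3,))
print(bool(jnp.all(a == b)))  # True
```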