RNGs: key types and custom implementations
This issue tracks the introduction of typed random key arrays with pluggable PRNG implementations.
Work on this began with #6899, which mostly works out the “pluggable implementations” part, makes initial progress on the “typed” part, and sets things up throughout the codebase for backwards compatibility and upgrading. The main configuration flag guarding the system upgrade is `config.jax_enable_custom_prng`.
The motivation for pluggability is straightforward: we’d like to set up various implementations for random bit generation, and still have all of `jax.random` work on top of any one of them. While we’re at it, we could make this extensible.
The motivation for having keys reflected in types (somehow, at some level) is broadly to improve safety, and with it, to offer both users and the system some structural guarantees. Desiderata include:
- Moving to a user-facing representation of key arrays that reflects that they are indeed key arrays. Currently key arrays are plain `uint32` arrays to users, indistinguishable from any other data.
- Restricting operations on key arrays. The current plain array representation allows for key-invalidating operations (e.g. manual updating, addition, …). We’d like to disallow these, or at least render mistakes unlikely.
- More/better opportunities to check key misuse, reuse, and so on.
Unsurprisingly, key types also help for mapping keys to RNG implementations (for the pluggability part above).
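To make the desiderata above concrete, here is a minimal sketch of the problem with the plain-array representation. It assumes the default configuration (`jax_enable_custom_prng` left off and `threefry2x32` as the default implementation); the variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

# Under the default configuration, a key is just a uint32 array:
# nothing in its type says "this is a PRNG key".
key = jax.random.PRNGKey(0)
print(key.dtype, key.shape)

# Because it is a plain array, arithmetic on it is accepted silently.
# This is the kind of key-invalidating operation the issue wants to rule out.
bumped = key + 1
sample_from_bumped = jax.random.normal(bumped, (3,))

# The intended way to derive fresh keys is to split.
key, subkey = jax.random.split(key)
sample_from_subkey = jax.random.normal(subkey, (3,))
```

Both sampling calls run without complaint, which is exactly why a dedicated key type is attractive: the first one should ideally be rejected or at least made hard to write by accident.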
There are roughly two ways we’ve thought to approach endowing jax with key types, specifically, key-element-type arrays:
One is as a frontend component handled during staging, projected away to plain `u32` in staged IR, in analogy with how pytrees behave. To introduce a pytree type in Python that wraps an underlying key-data (say, `uint32`) array is not quite correct, since it would misbehave under `vmap`, `scan`, and even `jax.tree_map`. Instead, we might want to rely on something like the typeclass mechanisms still in development (e.g. `vmappable`, #8451) for this approach.
The other tack is to introduce key types into our IR and other internal machinery (with a corresponding lowering), and to map to that from a Python array-of-keys-like type during staging. This might confer some extra advantages downstream as well, such as more opportunities for checking throughout our front- and middle-end. This is roughly the approach we’re taking (as of around summer '22).
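As a rough illustration of what a key element type looks like from the user's side, the sketch below uses `jax.random.key`, `jax.random.key_data`, and `jax.dtypes.prng_key`. These are assumptions relative to this issue: they come from later JAX releases and are not mentioned in the text above, so treat this as a sketch that requires a recent JAX version rather than a description of the design discussed here.

```python
import jax

# Assumes a recent JAX release that provides typed keys via jax.random.key.
key = jax.random.key(0)

# The key is a scalar array whose dtype marks it as a PRNG key.
is_key_dtype = jax.dtypes.issubdtype(key.dtype, jax.dtypes.prng_key)

# Arithmetic on a typed key is rejected instead of silently corrupting it.
try:
    key + 1
    arithmetic_rejected = False
except TypeError:
    arithmetic_rejected = True

# The underlying uint32 data is still reachable, but only explicitly.
raw = jax.random.key_data(key)

# Typed keys behave like ordinary array elements under vmap.
keys = jax.random.split(key, 4)
samples = jax.vmap(lambda k: jax.random.normal(k, (2,)))(keys)
```

Note how `keys` has shape `(4,)` rather than `(4, 2)`: each element is one key, and the raw words are hidden behind the element type.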
Ahead of the complete upgrade, some by-products are already available. We’ve implemented a couple of PRNGs that serve as alternatives to the default `threefry2x32` implementation (in #8067, #8123) and build on compiler bit generator primitives. The process-wide default RNG implementation can be controlled via `jax.default_prng_impl` and `config.jax_default_prng_impl` (from #8135) and can be accessed via `jax.random.default_prng_impl` (from #9186). These can be used to swap between the pre-defined RNG alternatives for the entire process at a time.
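A minimal sketch of swapping the process-wide default through the config flag follows. It assumes `"rbg"` is one of the registered implementation names alongside `"threefry2x32"`, and that `jax_enable_custom_prng` is left off so that keys are raw `uint32` arrays.

```python
import jax

# Switch the process-wide default implementation.
# "rbg" is assumed here to be one of the pre-defined alternatives.
jax.config.update("jax_default_prng_impl", "rbg")
rbg_key = jax.random.PRNGKey(0)
# Sample while the matching implementation is still the default:
# a raw key is interpreted according to the current default.
rbg_sample = jax.random.uniform(rbg_key, (3,))

# Switch back to the standard default.
jax.config.update("jax_default_prng_impl", "threefry2x32")
threefry_key = jax.random.PRNGKey(0)
threefry_sample = jax.random.uniform(threefry_key, (3,))

# The raw key layout depends on the implementation.
print(rbg_key.shape, threefry_key.shape)
```

Because raw keys carry no record of which implementation produced them, a key created under one default should not be reused after the default has been switched. This is one of the places where typed keys help.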
Yes, very good point.
Sounds great. And I love the new RNG interface if it wasn’t obvious 😄
We plan on starting this migration once `PRNGKeyArray` is no longer a PyTree, i.e. when we can register it as a JAX type through the upcoming typeclass mechanism (like `vmappable`, https://github.com/google/jax/pull/8451).