RNGs: key types and custom implementations

This issue tracks the introduction of typed random key arrays with pluggable PRNG implementations.

Work on this began with #6899, which mostly works out the "pluggable implementations" part, makes initial progress on the "typed" part, and prepares the codebase throughout for backward compatibility and a gradual upgrade. The main configuration flag guarding the upgrade is config.jax_enable_custom_prng.
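
As a rough illustration of how a flag can gate such an upgrade while old code keeps working, here is a toy sketch in plain Python. The `enable_custom_prng` parameter is a hypothetical stand-in for config.jax_enable_custom_prng, and `TypedKey` and `prng_key` are made-up names, not JAX APIs.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass(frozen=True)
class TypedKey:
    """Hypothetical typed key: raw uint32 words plus the name of the PRNG impl."""
    impl: str
    data: Tuple[int, ...]


def prng_key(seed: int, enable_custom_prng: bool) -> Union[TypedKey, List[int]]:
    # Toy seeding: two uint32 words derived from the seed (illustrative only).
    data = ((seed >> 32) & 0xFFFFFFFF, seed & 0xFFFFFFFF)
    if enable_custom_prng:
        # Upgraded path: the key's type records which implementation it belongs to.
        return TypedKey("threefry2x32", data)
    # Legacy path: callers keep receiving plain uint32 data, as before.
    return list(data)


legacy = prng_key(42, enable_custom_prng=False)
typed = prng_key(42, enable_custom_prng=True)
```

Code written against the legacy representation keeps working while the flag is off; turning it on switches key construction over to the typed representation.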

The motivation for pluggability is straightforward: we'd like to set up various implementations for random bit generation and still have all of jax.random work on top of any one of them. While we're at it, we can also make the set of implementations extensible.
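
To make the pluggability idea concrete, the toy sketch below (plain Python, not JAX code) bundles an implementation's operations into one record and writes a sampler once against that interface. `PRNGImpl`, `make_toy_impl`, and `uniform` are hypothetical names, and the hash-based bit generation is a stand-in, not threefry2x32 or any real JAX implementation.

```python
import hashlib
import struct
from dataclasses import dataclass
from typing import Callable, List, Tuple

KeyData = Tuple[int, ...]  # raw key data: a tuple of uint32 words


def _derive_words(words: KeyData, tag: bytes, count: int) -> KeyData:
    """Derive `count` uint32 words from key words and a tag (toy mixing via BLAKE2b)."""
    payload = struct.pack(f"<{len(words)}I", *words) + tag
    out: List[int] = []
    counter = 0
    while len(out) < count:
        digest = hashlib.blake2b(payload + struct.pack("<I", counter), digest_size=32).digest()
        out.extend(struct.unpack("<8I", digest))
        counter += 1
    return tuple(out[:count])


@dataclass(frozen=True)
class PRNGImpl:
    """Hypothetical bundle of the operations a pluggable PRNG must provide."""
    name: str
    key_words: int  # uint32 words per key
    seed: Callable[[int], KeyData]
    split: Callable[[KeyData, int], Tuple[KeyData, ...]]
    random_bits: Callable[[KeyData, int], KeyData]


def make_toy_impl(name: str, key_words: int) -> PRNGImpl:
    tag = name.encode()

    def seed(s: int) -> KeyData:
        return _derive_words((s & 0xFFFFFFFF, (s >> 32) & 0xFFFFFFFF), tag + b"/seed", key_words)

    def split(key: KeyData, num: int) -> Tuple[KeyData, ...]:
        words = _derive_words(key, tag + b"/split", num * key_words)
        return tuple(words[i * key_words:(i + 1) * key_words] for i in range(num))

    def random_bits(key: KeyData, count: int) -> KeyData:
        return _derive_words(key, tag + b"/bits", count)

    return PRNGImpl(name, key_words, seed, split, random_bits)


def uniform(impl: PRNGImpl, key: KeyData, count: int) -> List[float]:
    """Written once against the interface; works on top of any PRNGImpl."""
    return [bits / 2**32 for bits in impl.random_bits(key, count)]


TOY_2WORD = make_toy_impl("toy_2word", key_words=2)
TOY_4WORD = make_toy_impl("toy_4word", key_words=4)
```

Swapping `TOY_2WORD` for `TOY_4WORD` changes the key size and the bit stream, but `uniform` itself does not change. Extensibility then amounts to letting users construct their own `PRNGImpl` records.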

The motivation for having keys reflected in types (somehow, at some level) is broadly to improve safety, and with it, to offer both users and the system some structural guarantees. Desiderata include:

  • Moving to a user-facing representation of key arrays that reflects that they are indeed key arrays. Currently, key arrays appear to users as plain uint32 arrays, indistinguishable from any other data.
  • Restricting operations on key arrays. The current plain array representation allows for key-invalidating operations (e.g. manual updating, addition, …). We’d like to disallow these, or at least render mistakes unlikely.
  • More and better opportunities to check for key misuse, reuse, and so on.
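
One way to picture the first two desiderata is a wrapper type that is visibly a key array and refuses key-invalidating operations. The plain-Python sketch below is hypothetical (`KeyArray` here is not JAX's class); it allows indexing, which only selects keys, and rejects arithmetic and in-place updates.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class KeyArray:
    """Hypothetical typed key array: one tuple of uint32 words per key, plus an impl name."""
    impl: str
    data: Tuple[Tuple[int, ...], ...]

    @property
    def shape(self) -> Tuple[int, ...]:
        # The logical shape counts keys, not the uint32 words inside each key.
        return (len(self.data),)

    def __getitem__(self, idx):
        # Selecting keys keeps them intact, so indexing is allowed.
        picked = self.data[idx]
        if isinstance(idx, int):
            picked = (picked,)  # this toy has no 0-d arrays; keep a batch of one
        return KeyArray(self.impl, picked)

    def __add__(self, other):
        # Arithmetic would silently produce invalid keys, so refuse it outright.
        raise TypeError("arithmetic on key arrays is not supported")

    __radd__ = __sub__ = __rsub__ = __mul__ = __rmul__ = __add__
    # No __setitem__ is defined, so `keys[0] = ...` raises TypeError, and
    # frozen=True blocks reassigning `keys.data` directly.


keys = KeyArray("threefry2x32", ((0, 1), (2, 3)))
```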

Unsurprisingly, key types also help with mapping keys to RNG implementations (for the pluggability part above).
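
Here is a small sketch of that mapping, assuming each key carries the name of its implementation and a registry resolves that name to the implementation's functions. The registry, `Key`, `split`, and the two toy split rules are all hypothetical and not cryptographically meaningful.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

# Hypothetical registry: impl name -> that impl's split function on raw key data.
_SPLIT_IMPLS: Dict[str, Callable[[Tuple[int, ...], int], List[Tuple[int, ...]]]] = {}


def register_split(name: str):
    def deco(fn):
        _SPLIT_IMPLS[name] = fn
        return fn
    return deco


@dataclass(frozen=True)
class Key:
    impl: str
    data: Tuple[int, ...]


def split(key: Key, num: int = 2) -> List[Key]:
    # The key's type (its impl tag) selects which implementation runs.
    fn = _SPLIT_IMPLS[key.impl]
    return [Key(key.impl, d) for d in fn(key.data, num)]


@register_split("counter2")
def _split_counter2(data, num):
    # Toy rule (NOT a real PRNG): mix the child index into both words.
    a, b = data
    return [((a * 31 + i) & 0xFFFFFFFF, (b ^ (i * 0x9E3779B9)) & 0xFFFFFFFF) for i in range(num)]


@register_split("counter4")
def _split_counter4(data, num):
    # Toy rule for 4-word keys: offset every word by the child index plus one.
    return [tuple((w + i + 1) & 0xFFFFFFFF for w in data) for i in range(num)]
```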

We've considered roughly two ways to endow JAX with key types, specifically key-element-type arrays:

One is to treat keys as a frontend component handled during staging and projected away to plain u32 in the staged IR, analogous to how pytrees behave. Simply introducing a Python pytree type that wraps an underlying key-data (say, uint32) array is not quite correct, since it would misbehave under vmap, scan, and even jax.tree_map. For this approach, we would instead want to rely on something like the typeclass mechanisms still in development (e.g. vmappable, #8451).
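
The tree_map problem can be seen with a toy model. Below, a key wrapper is treated as a pytree node whose only leaf is its raw uint32 data, which is roughly what registering it with jax.tree_util.register_pytree_node would amount to. The hand-rolled `tree_flatten` and `tree_map` here are simplified stand-ins, not JAX's. Because tree_map hands the raw data to arbitrary functions, the wrapper's ban on arithmetic is bypassed.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple


@dataclass(frozen=True)
class PytreeKey:
    """Hypothetical key wrapper that forbids arithmetic on itself."""
    data: Tuple[int, ...]

    def __add__(self, other):
        raise TypeError("arithmetic on keys is not supported")


def tree_flatten(tree: Any) -> Tuple[List[Any], Callable[[List[Any]], Any]]:
    # Toy flattening: a PytreeKey exposes its raw uint32 data as its only leaf;
    # anything else is treated as a single leaf.
    if isinstance(tree, PytreeKey):
        return [tree.data], lambda leaves: PytreeKey(leaves[0])
    return [tree], lambda leaves: leaves[0]


def tree_map(f: Callable[[Any], Any], tree: Any) -> Any:
    leaves, unflatten = tree_flatten(tree)
    return unflatten([f(leaf) for leaf in leaves])


key = PytreeKey((7, 11))
# The wrapper refuses `key + 1`, but tree_map passes the raw words straight to `f`:
bumped = tree_map(lambda words: tuple(w + 1 for w in words), key)
```

The result looks like a valid key but its data was edited by hand, which is exactly the kind of mistake key types are meant to rule out.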

The other tack is to introduce key types into our IR and other internal machinery (with a corresponding lowering), and to map a Python array-of-keys-like type onto them during staging. This may bring extra advantages downstream as well, such as more opportunities for checking throughout our front end and middle end. This is roughly the approach we're taking (as of around summer '22).
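
A toy model of this approach: abstract values in the IR carry a key element type, a checker can reject key-invalidating primitives on them, and lowering maps each key-typed value to a uint32 array with the implementation's key shape appended. `KeyDType`, `ShapedAval`, `check_add`, and `lower_aval` are hypothetical names; the only concrete fact used is that a threefry2x32 key is two uint32 words.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class KeyDType:
    """Hypothetical key element type, parameterized by its PRNG implementation."""
    impl: str
    key_shape: Tuple[int, ...]  # uint32 shape of a single key under this impl


@dataclass(frozen=True)
class ShapedAval:
    """Hypothetical abstract value: a shape plus an element type."""
    shape: Tuple[int, ...]
    dtype: Union[str, KeyDType]  # e.g. "float32", "uint32", or a KeyDType


def check_add(x: ShapedAval, y: ShapedAval) -> ShapedAval:
    # Checking in the IR: key-typed operands are rejected before lowering.
    if isinstance(x.dtype, KeyDType) or isinstance(y.dtype, KeyDType):
        raise TypeError("add is not defined on key element types")
    if x != y:
        raise TypeError("add expects matching shapes and dtypes")
    return x


def lower_aval(aval: ShapedAval) -> ShapedAval:
    # Lowering: a key array of logical shape S becomes uint32[S + key_shape].
    if isinstance(aval.dtype, KeyDType):
        return ShapedAval(aval.shape + aval.dtype.key_shape, "uint32")
    return aval


THREEFRY = KeyDType("threefry2x32", (2,))
keys = ShapedAval((8,), THREEFRY)
```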

Ahead of the complete upgrade, some by-products are already available. We've implemented a couple of PRNGs that serve as alternatives to the default threefry2x32 implementation (in #8067, #8123); they build on compiler bit-generator primitives. The process-wide default RNG implementation can be controlled via jax.default_prng_impl and config.jax_default_prng_impl (from #8135) and can be accessed via jax.random.default_prng_impl (from #9186). These can be used to swap between the predefined RNG alternatives for the entire process at once.
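
The effect of a process-wide default can be sketched with a module-level setting that key construction reads. This toy registry is hypothetical: in JAX the analogous knob is config.jax_default_prng_impl (for example, `jax.config.update("jax_default_prng_impl", "rbg")`), while the names, the "toy_4word" implementation, and the seeding rule below are made up for illustration.

```python
from typing import Dict, Tuple

# Hypothetical registry: impl name -> number of uint32 words per key.
_KEY_WORDS: Dict[str, int] = {"threefry2x32": 2, "toy_4word": 4}
_default_impl = "threefry2x32"


def set_default_prng_impl(name: str) -> None:
    """Swap the implementation used by every later key construction in this process."""
    global _default_impl
    if name not in _KEY_WORDS:
        raise ValueError(f"unknown PRNG implementation: {name!r}")
    _default_impl = name


def default_prng_impl() -> str:
    return _default_impl


def make_key(seed: int) -> Tuple[str, Tuple[int, ...]]:
    # Toy seeding: fill the key with seed, seed + 1, ... (illustrative only).
    words = _KEY_WORDS[_default_impl]
    return _default_impl, tuple((seed + i) & 0xFFFFFFFF for i in range(words))
```

Because the setting is global, every key created after the swap uses the new implementation, which matches the "entire process at a time" granularity described above.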

Issue Analytics

  • State: open
  • Created 2 years ago
  • Reactions: 3
  • Comments: 9 (8 by maintainers)

Top GitHub Comments

NeilGirdhar commented, Sep 28, 2022 (1 reaction)

> So long as any amount of numpy is supported, I'm wary of possible ambiguity.

Yes, very good point.

> I otherwise appreciate your case in favor of it. We can start with less for now and see whether we're drawn to add it.

Sounds great. And I love the new RNG interface, if it wasn't obvious 😄

LenaMartens commented, Jan 21, 2022 (1 reaction)

We plan on starting this migration once PRNGKeyArray is no longer a pytree, i.e., once we can register it as a JAX type through the upcoming typeclass mechanism (like vmappable, https://github.com/google/jax/pull/8451).