`random.split` should support splitting starting at a particular index, such as `step` number
Background: during training, we can deterministically derive the PRNG seed from the current step number, by `seed=hash(step_number)`. This can sometimes be convenient because it avoids the need to checkpoint the PRNG seed, as the step number is all the state that is needed.
The computation `seed=hash(step_number)` can be implemented, albeit inefficiently, in JAX:

```python
key_this_step = jax.random.split(global_key, step_number + 1)[step_number]
```
Sadly, this is inefficient: it materializes an intermediate array of length `step_number + 1`, even though we only want the last element of that array. Given that all of the elements of `jax.random.split` are computed in parallel anyway, it would be nice to have direct access to the `n`th split of `jax.random.split` rather than producing all of them and selecting one.
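To make the cost concrete, here is a runnable version of the split-then-index pattern. The helper name `key_for_step_via_split` is hypothetical. It also shows one further cost, assuming standard `jax.jit` behavior: because `num` determines the output shape, `step_number` has to be a static argument, so each new step value triggers a fresh trace and compile.

```python
import jax

def key_for_step_via_split(global_key, step_number):
    # Materializes all step_number + 1 subkeys (shape (step_number + 1, 2) for a
    # raw PRNGKey), then keeps only the last row: work and memory grow with the step.
    return jax.random.split(global_key, step_number + 1)[step_number]

# `num` fixes the output shape, so under jit step_number must be static;
# each distinct step value is traced and compiled separately.
key_for_step_jit = jax.jit(key_for_step_via_split, static_argnums=1)

global_key = jax.random.PRNGKey(0)
key_this_step = key_for_step_jit(global_key, 5)
```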
One way to integrate this into the API is to add a parameter `slice=None` to `jax.random.split`, whose behavior is defined by requiring that the following equivalence holds:

```python
jax.random.split(key, slice=(lo, hi)) == jax.random.split(key, num)[lo:hi]
```
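A minimal sketch of the reference semantics for the proposed `slice` option, which could serve, for example, as a test oracle for an efficient implementation. `split_slice_reference` is a hypothetical helper, not a JAX API. Because the proposal leaves `num` implicit, this sketch takes it explicitly, and it still materializes all `num` subkeys, which is exactly the cost the proposal aims to avoid.

```python
import jax

def split_slice_reference(key, num, lo, hi):
    """Hypothetical reference for the proposed split(key, slice=(lo, hi)).

    Not a JAX API. It spells out the proposed equivalence literally, so it
    still builds all `num` subkeys before slicing.
    """
    if not 0 <= lo <= hi <= num:
        raise ValueError("expected 0 <= lo <= hi <= num")
    return jax.random.split(key, num)[lo:hi]
```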
I note that there are other requests for extensions to `jax.random.split`, such as https://github.com/google/jax/issues/4013. Another alternative, which seems flexible enough to serve everyone's needs, would be to directly expose the underlying hash function as a function `jax.random.hash(key, index)`, defined so that the following equivalence holds:

```python
jax.random.split(key, num) == jax.random.hash(key, jax.lax.iota(jnp.uint32, num))
```
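The per-index `hash(key, index)` interface can be approximated today by mapping `jax.random.fold_in` over indices with `jax.vmap`. This is a sketch under an explicit assumption: `keys_for_indices` is a hypothetical name, and its outputs are not the values `jax.random.split(key, num)` would produce, so it imitates the shape of the proposed interface but not the equality itself. Its cost scales with the number of requested indices rather than with the largest index.

```python
import jax
import jax.numpy as jnp

def keys_for_indices(key, indices):
    """Hypothetical helper: one subkey per entry of `indices`, via fold_in.

    NOTE: these keys differ from jax.random.split(key, num)[indices];
    only the "hash each index independently" interface is imitated.
    """
    return jax.vmap(lambda i: jax.random.fold_in(key, i))(indices)

key = jax.random.PRNGKey(0)
# Analogue of the proposed slice=(2, 5): only three keys are computed.
subkeys = keys_for_indices(key, jnp.arange(2, 5, dtype=jnp.uint32))
```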
`random.fold_in(key, step)` looks great, thanks for pointing me to it!

Thanks for the suggestion! This is a really good time to bring this up, because @froystig has just finished a refactoring of the PRNG code so that users can override the API, including `_split`. This means that if we are planning to make any changes to the options `_split` accepts, now would be the time (before downstream libraries implement their own random generators with their own `split` methods). I wonder if @froystig has thoughts about this request, and the one in #4013?
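For reference, a minimal sketch of the `random.fold_in(key, step)` approach mentioned above; the function name `train_step_key` is hypothetical. Because `fold_in` hashes a single integer into the key, its cost does not grow with `step`, and under `jax.jit` the step can be an ordinary traced argument, so one compilation covers every step. Note that the keys it yields are not the same as `jax.random.split(global_key, step + 1)[step]`, so switching changes the random stream.

```python
import jax

@jax.jit
def train_step_key(global_key, step):
    # fold_in mixes one integer into the key: constant work per step,
    # and `step` stays a traced value, so this compiles only once.
    return jax.random.fold_in(global_key, step)

global_key = jax.random.PRNGKey(0)
for step in range(3):
    key_this_step = train_step_key(global_key, step)
    noise = jax.random.normal(key_this_step, (4,))  # per-step randomness
```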