[Bug?] Jax shuffle speed versus numpy random shuffle speed
I was reexamining a training loop I’d made in JAX for a talk. When I tried to demonstrate that `jax.random.shuffle` gives a speedup over `numpy.random.shuffle`, I found that the latter was actually significantly faster. I then recreated it in a small script and found that the JAX shuffle gets really slow for ndarrays of around 1e6 to 1e7 elements, both of which are sizes we easily encounter.
The demonstration code and timings are in this gist, but basically at around 1e7 elements `numpy.random.shuffle` takes a neat 3 seconds while `jax.random.shuffle` clocks in at over 4 minutes!
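The gist itself is not reproduced here. As a minimal sketch of how such a comparison could be timed (not the original script), the snippet below uses two hypothetical helpers, `time_numpy_shuffle` and `time_jax_permutation`. It assumes `jax.random.permutation` as the stand-in for `jax.random.shuffle`, runs one warm-up call so compilation is not counted, and calls `block_until_ready()` because JAX dispatches work asynchronously. The array sizes are deliberately small; timings will differ by machine and backend.

```python
import time

import numpy as np
import jax
import jax.numpy as jnp


def time_numpy_shuffle(n):
    """Seconds taken by an in-place numpy.random.shuffle on n elements."""
    x = np.arange(n)
    start = time.perf_counter()
    np.random.shuffle(x)  # shuffles x in place, returns None
    return time.perf_counter() - start


def time_jax_permutation(n, seed=0):
    """Seconds taken by jax.random.permutation on n elements (hypothetical helper)."""
    key = jax.random.PRNGKey(seed)
    x = jnp.arange(n)
    # Warm-up call so one-time compilation is excluded from the measurement.
    jax.random.permutation(key, x).block_until_ready()
    start = time.perf_counter()
    # block_until_ready waits for the asynchronously dispatched result.
    jax.random.permutation(key, x).block_until_ready()
    return time.perf_counter() - start


if __name__ == "__main__":
    # Small sizes for a quick run; raise them to approach the 1e6 to 1e7 range.
    for n in (10**4, 10**5):
        print(n, time_numpy_shuffle(n), time_jax_permutation(n))
```

Without the `block_until_ready()` call, the JAX timing may only measure how long it takes to enqueue the work rather than to finish it.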
Is this intentional? Is it unavoidable due to not being able to do it in-place? I tried reading into the JAX shuffle implementation a bit and it seems like a lot of effort goes into making it rather fast, and the numpy random shuffle has the advantage of doing everything in place, so that would all make sense to me.
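To make the in-place question concrete, here is a rough pure-Python sketch of the two styles: an in-place Fisher-Yates shuffle and an out-of-place shuffle that sorts by random keys. This is only an illustration of the trade-off, not JAX's or NumPy's actual implementation, and both function names are hypothetical.

```python
import random


def fisher_yates_in_place(xs, rng):
    """Shuffle the list xs in place using O(n) swaps and no extra array."""
    for i in range(len(xs) - 1, 0, -1):
        j = rng.randint(0, i)  # randint bounds are inclusive
        xs[i], xs[j] = xs[j], xs[i]


def shuffle_by_random_keys(xs, rng):
    """Return a new shuffled list by sorting indices on random keys, O(n log n)."""
    keys = [rng.random() for _ in xs]
    order = sorted(range(len(xs)), key=keys.__getitem__)
    return [xs[i] for i in order]  # the input list is left untouched


if __name__ == "__main__":
    rng = random.Random(0)
    data = list(range(10))
    print(shuffle_by_random_keys(data, rng))  # new list, data unchanged
    fisher_yates_in_place(data, rng)
    print(data)  # data itself is now shuffled
```

The in-place version mutates its argument, which immutable arrays do not allow, so an out-of-place approach has to allocate a result and typically does more work per element.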
Issue Analytics
- Created 4 years ago
- Comments: 5 (3 by maintainers)
Revisiting this: `jax.random.shuffle` has been deprecated in favor of `jax.random.permutation`, but the performance issue remains in the current main branch of JAX (run on a Colab CPU):
On GPU, JAX looks better:
I’m going to assign this to @froystig, as he’s been working on PRNG updates and may have ideas about how to improve performance.
Keen to hear if there is any update on this. (Happy to also help if directed!!)