Generating random numbers with `jax.random.split` can be >200x slower than `np.random.normal`
Currently, I'm relying on `jax.random.split` and `jax.random.normal` for random number generation. I expect generating random numbers with this combination to be slower, but the following results (on CPU) are still surprising:
```python
import time
from jax import random, grad
import jax.numpy as np
import numpy.random as npr

def sample_repeatedly_with_split(key):
    for _ in range(10000):
        key, subkey = random.split(key)
        random.normal(subkey, shape=(3,))

def sample_repeatedly():
    for _ in range(10000):
        npr.normal(size=(3,))

key = random.PRNGKey(0)

now = time.time()
sample_repeatedly_with_split(key=key)
print('sample with split takes {:.4f} secs'.format(time.time() - now))

now = time.time()
sample_repeatedly()
print('`npr.normal` takes {:.4f} secs'.format(time.time() - now))
```
Results:

```
sample with split takes 8.6022 secs
`npr.normal` takes 0.0296 secs
```
Some profiling results (with `cProfile` and `pstats`) show:

```
myscript.cprof% stats 20
Wed Jul 3 10:13:14 2019    myscript.cprof

         7879396 function calls (7264157 primitive calls) in 8.870 seconds

   Ordered by: cumulative time
   List reduced from 1909 to 20 due to restriction <20>

     ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      293/1    0.002    0.000    8.877    8.877 {built-in method builtins.exec}
          1    0.000    0.000    8.877    8.877 split_prof.py:1(<module>)
          1    0.091    0.091    8.602    8.602 split_prof.py:12(sample_repeatedly_with_split)
20003/20000    0.197    0.000    4.365    0.000 api.py:109(f_jitted)
      30000    0.034    0.000    4.076    0.000 xla.py:518(<genexpr>)
      20001    0.102    0.000    4.042    0.000 lax_numpy.py:2161(_rewriting_take)
      20004    0.089    0.000    3.580    0.000 lax.py:1206(index_in_dim)
20007/20000    0.143    0.000    3.062    0.000 core.py:656(call_bind)
20003/20000    0.090    0.000    2.574    0.000 xla.py:604(xla_call_impl)
40274/40273    0.081    0.000    2.410    0.000 core.py:139(bind)
      10000    0.028    0.000    2.295    0.000 random.py:376(normal)
      10000    0.021    0.000    2.121    0.000 random.py:161(split)
      40006    0.134    0.000    2.114    0.000 xla.py:50(apply_primitive)
      20005    0.048    0.000    1.781    0.000 lax.py:1192(slice_in_dim)
      20009    0.271    0.000    1.733    0.000 lax.py:586(slice)
20003/20000    0.085    0.000    1.660    0.000 linear_util.py:199(memoized_fun)
      40006    0.983    0.000    1.318    0.000 xla.py:83(execute_compiled_primitive)
      20013    0.128    0.000    1.284    0.000 lax.py:549(reshape)
      20003    0.503    0.000    0.691    0.000 xla.py:629(execute_compiled)
      59994    0.116    0.000    0.598    0.000 linear_util.py:173(__eq__)
```
Note that I named my script `split_prof.py`. It seems there's considerable overhead from XLA, even when I'm not explicitly jitting any functions.
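One way to see where that overhead might come from is to trace a single loop iteration with `jax.make_jaxpr` and inspect the operations it contains. This is a sketch assuming a current JAX install; `one_step` is a hypothetical helper that mirrors the loop body above, and the exact primitives printed vary between JAX versions, so no output is shown.

```python
import jax
from jax import random

def one_step(key):
    # Hypothetical helper: same work as one iteration of
    # sample_repeatedly_with_split (split, unpack, draw 3 normals).
    key, subkey = random.split(key)
    return key, random.normal(subkey, shape=(3,))

key = random.PRNGKey(0)
closed_jaxpr = jax.make_jaxpr(one_step)(key)

# When this code runs op-by-op (outside jit), each top-level equation is,
# roughly speaking, dispatched separately, each paying Python-side overhead.
print(closed_jaxpr)
print("top-level equations:", len(closed_jaxpr.jaxpr.eqns))
```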
Thanks for the clear benchmark!
I think this is essentially timing dispatch overheads. NumPy’s dispatch overheads are a lot lower than JAX’s, which makes it much faster at doing lots of small operations. (Interestingly, lots of small operations is what Python+NumPy is already considered bad at compared to something like pure C. One way to think about JAX, at least in its current state, is that it pushes that contrast further, in that it’s even better than NumPy at large array-oriented operations because of its jit compilation and use of accelerators, but it’s even worse at doing lots of small operations.)
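A minimal sketch of how to see that per-operation dispatch cost directly, assuming a current JAX install: time a trivial 3-element addition many times in NumPy and in JAX. For arrays this small the arithmetic itself is negligible, so the per-call time is mostly overhead. The absolute numbers depend on the machine and JAX version.

```python
import timeit
import numpy as onp
import jax.numpy as jnp

x_np = onp.ones(3)
x_jax = jnp.ones(3)

# Warm-up so the JAX timing excludes one-time compilation of the add.
(x_jax + 1.0).block_until_ready()

n = 1000
# Average wall-clock time per call for a tiny elementwise add.
t_np = timeit.timeit(lambda: x_np + 1.0, number=n) / n
t_jax = timeit.timeit(lambda: (x_jax + 1.0).block_until_ready(), number=n) / n

print(f"NumPy per-op: {t_np * 1e6:.1f} us")
print(f"JAX per-op:   {t_jax * 1e6:.1f} us")
```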
One way to make things faster is to use `jit`. That sped things up, but only by a factor of 2. (That timing also includes compilation time, though that's probably small.)
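That approach can be sketched as follows, assuming the loop body (the split plus the normal draw) is what gets wrapped in `jit`. `split_and_sample` and `sample_repeatedly_with_split_jit` are hypothetical names, and this is not necessarily the exact code behind the 2x figure above.

```python
import time
import jax
from jax import random

@jax.jit
def split_and_sample(key):
    # One compiled XLA computation: split the key, then draw 3 normals.
    key, subkey = random.split(key)
    return key, random.normal(subkey, shape=(3,))

def sample_repeatedly_with_split_jit(key, n=10000):
    samples = []
    for _ in range(n):
        # Still one dispatch per iteration, but one instead of several.
        key, x = split_and_sample(key)
        samples.append(x)
    return key, samples

key = random.PRNGKey(0)
now = time.time()
key, samples = sample_repeatedly_with_split_jit(key)
samples[-1].block_until_ready()  # wait for async dispatch to finish
print('jitted split+normal takes {:.4f} secs'.format(time.time() - now))
```

Because the Python loop is still there, each iteration still pays for one dispatch of the compiled function, which would be consistent with a modest speedup rather than a dramatic one.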
To measure something other than dispatch overheads (which aren't specific to the PRNG and would show up in pretty much any JAX-vs-NumPy microbenchmark like this one), we can make the arrays bigger. Here are a few different sizes, the largest being 30000 (which, don't forget, is smaller than a 200x200 array with 40000 elements, so these aren't very big sizes):
Here’s the full script (check for bugs!):
The `block_until_ready` function is there to prevent JAX from getting an unfair advantage from its async dispatch, which lets us hide dispatch overheads and device latencies behind the real numerical work going on ~~(but in this case it doesn't make a difference, because all the time is spent in Python overheads, so overlapping the compute with the Python doesn't buy us anything)~~. (EDITED)

Still, if you want to generate lots of small arrays and can't put everything under a `jit`, that's the kind of workload for which NumPy is better than JAX. What do you think?
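A minimal sketch of such a size sweep, assuming a current JAX install and using `block_until_ready` as described. The helper names `time_numpy` and `time_jax`, the repetition count, and the sizes other than 30000 are illustrative assumptions, not the script from this thread.

```python
import time
import numpy as onp
from jax import random

def time_numpy(size, reps=100):
    # Average seconds per NumPy draw of `size` standard normals.
    start = time.perf_counter()
    for _ in range(reps):
        onp.random.normal(size=(size,))
    return (time.perf_counter() - start) / reps

def time_jax(size, reps=100):
    key = random.PRNGKey(0)
    # Warm-up: compile split and normal for this shape before timing.
    _, subkey = random.split(key)
    random.normal(subkey, shape=(size,)).block_until_ready()
    start = time.perf_counter()
    for _ in range(reps):
        key, subkey = random.split(key)
        # Wait for the result so async dispatch can't hide the device work.
        random.normal(subkey, shape=(size,)).block_until_ready()
    return (time.perf_counter() - start) / reps

for size in [3, 300, 30000]:
    print(f"size={size:>6}  numpy={time_numpy(size):.2e}s  jax={time_jax(size):.2e}s")
```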
(It might be best to open a separate issue thread.)
Indeed, not everything can be jitted, and this is a good example of code that JAX can't `jit` and must execute in an "op-by-op" fashion.
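When the surrounding loop genuinely can't be jitted, one common workaround, assuming the number of draws is known before the loop starts, is to draw all the randomness in a single JAX call and convert it to a NumPy array once, so the loop indexes a plain NumPy array instead of dispatching JAX operations on every step. `presample_normals` is a hypothetical helper; note that it produces a different (but identically distributed) stream of samples than splitting the key on every iteration.

```python
import numpy as onp
from jax import random

def presample_normals(key, n, shape=(3,)):
    # A single JAX call draws all n samples at once, instead of
    # dispatching split + normal on every loop iteration.
    key, subkey = random.split(key)
    draws = random.normal(subkey, shape=(n,) + shape)
    # Convert once, so per-iteration indexing below is plain NumPy
    # and does not go through JAX dispatch.
    return key, onp.asarray(draws)

key = random.PRNGKey(0)
key, draws = presample_normals(key, 10000)
for x in draws:
    pass  # op-by-op Python code consuming one (3,) sample per step
```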