
Generating random numbers with `jax.random.split` can be >200x slower than `np.random.normal`


Currently, I’m relying on `jax.random.split` and `jax.random.normal` for random number generation. I expected this combination to be somewhat slower than NumPy, but the size of the gap in the following results (on CPU) is still surprising:

import time
from jax import random
import numpy.random as npr

def sample_repeatedly_with_split(key):
    # JAX: split off a fresh subkey for every draw
    for _ in range(10000):
        key, subkey = random.split(key)
        random.normal(subkey, shape=(3,))


def sample_repeatedly():
    # NumPy: draw from the global RNG state
    for _ in range(10000):
        npr.normal(size=(3,))


key = random.PRNGKey(0)
now = time.time()
sample_repeatedly_with_split(key=key)
print('sample with split takes {:.4f} secs'.format(time.time() - now))

now = time.time()
sample_repeatedly()
print('`npr.normal` takes {:.4f} secs'.format(time.time() - now))

Results:

sample with split takes 8.6022 secs
`npr.normal` takes 0.0296 secs

Some profiling results (with cProfile and pstats) show:

myscript.cprof% stats 20
Wed Jul  3 10:13:14 2019    myscript.cprof

         7879396 function calls (7264157 primitive calls) in 8.870 seconds

   Ordered by: cumulative time
   List reduced from 1909 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    293/1    0.002    0.000    8.877    8.877 {built-in method builtins.exec}
        1    0.000    0.000    8.877    8.877 split_prof.py:1(<module>)
        1    0.091    0.091    8.602    8.602 split_prof.py:12(sample_repeatedly_with_split)
20003/20000    0.197    0.000    4.365    0.000 api.py:109(f_jitted)
    30000    0.034    0.000    4.076    0.000 xla.py:518(<genexpr>)
    20001    0.102    0.000    4.042    0.000 lax_numpy.py:2161(_rewriting_take)
    20004    0.089    0.000    3.580    0.000 lax.py:1206(index_in_dim)
20007/20000    0.143    0.000    3.062    0.000 core.py:656(call_bind)
20003/20000    0.090    0.000    2.574    0.000 xla.py:604(xla_call_impl)
40274/40273    0.081    0.000    2.410    0.000 core.py:139(bind)
    10000    0.028    0.000    2.295    0.000 random.py:376(normal)
    10000    0.021    0.000    2.121    0.000 random.py:161(split)
    40006    0.134    0.000    2.114    0.000 xla.py:50(apply_primitive)
    20005    0.048    0.000    1.781    0.000 lax.py:1192(slice_in_dim)
    20009    0.271    0.000    1.733    0.000 lax.py:586(slice)
20003/20000    0.085    0.000    1.660    0.000 linear_util.py:199(memoized_fun)
    40006    0.983    0.000    1.318    0.000 xla.py:83(execute_compiled_primitive)
    20013    0.128    0.000    1.284    0.000 lax.py:549(reshape)
    20003    0.503    0.000    0.691    0.000 xla.py:629(execute_compiled)
    59994    0.116    0.000    0.598    0.000 linear_util.py:173(__eq__)

Note that I named my script split_prof.py (hence the filenames above). It seems there’s considerable per-call overhead from JAX’s dispatch to XLA, even though I’m not explicitly jitting any functions.
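
To isolate the per-call cost outside of a loop, a quick sanity check (my sketch, not from the original issue) is to time a single un-jitted split-and-sample with `timeit`:

import timeit
from jax import random

key = random.PRNGKey(0)

def one_sample():
    # one un-jitted split + draw; dominated by dispatch overhead, not math
    _, subkey = random.split(key)
    return random.normal(subkey, shape=(3,)).block_until_ready()

t = timeit.timeit(one_sample, number=1000)
print('per call: {:.1f} us'.format(t / 1000 * 1e6))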

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

2 reactions
mattjj commented, Jul 3, 2019

Thanks for the clear benchmark!

I think this is essentially timing dispatch overheads. NumPy’s dispatch overheads are a lot lower than JAX’s, which makes it much faster at doing lots of small operations. (Interestingly, lots of small operations is what Python+NumPy is already considered bad at compared to something like pure C. One way to think about JAX, at least in its current state, is that it pushes that contrast further, in that it’s even better than NumPy at large array-oriented operations because of its jit compilation and use of accelerators, but it’s even worse at doing lots of small operations.)

One way to make things faster is to use jit, e.g.:

from jax import jit

@jit
def split_and_sample(key):
  key, subkey = random.split(key)
  val = random.normal(subkey, shape=(3,))
  return key, val

def sample_repeatedly_with_split(key):
  for _ in range(10000):
    key, _ = split_and_sample(key)

That sped things up, but only by a factor of 2. (That’s also including compilation time, though that’s probably small.)
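
To take compilation out of the measurement, one option (my suggestion, not from the thread) is to call the jitted function once before starting the clock, so tracing and compilation happen up front:

key = random.PRNGKey(0)
key, _ = split_and_sample(key)  # first call traces and compiles; excluded from timing

now = time.time()
sample_repeatedly_with_split(key)
print('after warm-up: {:.4f} secs'.format(time.time() - now))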

To measure something other than dispatch overheads, which aren’t specific to the PRNG but would show up in pretty much any JAX-vs-NumPy micro-benchmark like this, we can make the arrays bigger. Here are a few different sizes, the largest being 30000 (still smaller than a 200x200 array, which has 40000 elements, so these aren’t very big):

[Figure: issue968.png, a semilog plot of JAX vs. NumPy sampling times for the sizes below]

Here’s the full script (check for bugs!):

import time
from functools import partial
from jax import random, jit
import numpy.random as npr

# shape is a static argument so each new size triggers a fresh
# trace/compile; a global captured under jit would be baked in at the
# first call and silently reused for every subsequent size.
@partial(jit, static_argnums=1)
def split_and_sample(key, shape):
    key, subkey = random.split(key)
    val = random.normal(subkey, shape=shape)
    return key, val

def sample_repeatedly_with_split(key, shape):
    for _ in range(10000):
        key, _ = split_and_sample(key, shape)
    return key


def sample_repeatedly(shape):
    for _ in range(10000):
        npr.normal(size=shape)


jax_times, np_times = [], []
sizes = [3, 30, 300, 3000, 30000]
for size in sizes:
    shape = (size,)

    key = random.PRNGKey(0)
    now = time.time()
    sample_repeatedly_with_split(key, shape).block_until_ready()  # async!
    jax_times.append(time.time() - now)

    now = time.time()
    sample_repeatedly(shape)
    np_times.append(time.time() - now)

import matplotlib.pyplot as plt
plt.semilogy(sizes, jax_times, label="jax times")
plt.semilogy(sizes, np_times, label="np times")
plt.legend()
plt.savefig('issue968.png')

The block_until_ready function is to prevent JAX from getting an unfair advantage due to its async dispatch, which lets us hide dispatch overheads and device latencies behind the real numerical work going on ~~(but in this case it doesn’t make a difference because all the time is spent in Python overheads, so overlapping the compute with the Python doesn’t buy us anything)~~. (EDITED)
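
To see the async effect in isolation (my illustration, not from the thread), compare a timer stopped right after dispatch with one stopped after block_until_ready:

import time
import jax.numpy as jnp

x = jnp.ones((2000, 2000))
jnp.dot(x, x).block_until_ready()  # warm-up so compilation isn't timed

now = time.time()
y = jnp.dot(x, x)                  # returns immediately; compute proceeds asynchronously
print('dispatch only: {:.4f} secs'.format(time.time() - now))

y.block_until_ready()              # wait for the actual result
print('with compute:  {:.4f} secs'.format(time.time() - now))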

Still, if you want to generate lots of small arrays and can’t stick everything under a jit, that’s the kind of workload for which NumPy is better than JAX.

What do you think?

0 reactions
mattjj commented, Jul 15, 2019

(It might be best to open a separate issue thread.)

Indeed, not everything can be jitted, and this is a good example of code that JAX can’t jit and must instead execute in an “op-by-op” fashion.
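
That said, when the loop is as regular as this benchmark’s, it can be moved inside jit with lax.scan (a rough sketch under that assumption; not something proposed in the thread), paying dispatch overhead once rather than per iteration:

from jax import random, jit, lax

def step(key, _):
    # one split-and-sample step: carry the new key, emit the draw
    key, subkey = random.split(key)
    return key, random.normal(subkey, shape=(3,))

@jit
def sample_scan(key):
    # all 10000 steps run inside a single compiled computation
    return lax.scan(step, key, xs=None, length=10000)

key, vals = sample_scan(random.PRNGKey(0))  # vals.shape == (10000, 3)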


