pmap is pretty slow for functions involving jax.random
The following repro script shows the slowness described in the title. This happens on CPU and might happen on GPU too (but I don't have 2 GPUs to test).
```python
import jax.numpy as np
from jax import pmap, jit, random

# requires the flag: XLA_FLAGS=--xla_force_host_platform_device_count=12
from jax.config import config; config.update('jax_platform_name', 'cpu')

n = 10

@jit
def g(rng, x):
    key, subkey = random.split(rng)
    return key, x + 0.01 * random.normal(subkey)

def f(x):
    rng = random.PRNGKey(0)
    collection = []
    for i in range(n):
        rng, x = g(rng, x)
        collection.append(x)
    return np.stack(collection)

pmap(f)(np.array([0., 1.]))
```
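For a rough sense of the gap (my own sketch, not part of the original report): time the `pmap`-ed call against a single-device `jit(vmap(f))` baseline, reusing the definitions above and warming both up first so compilation time is excluded.

```python
import time
from jax import vmap

# Rough timing sketch (assumption: jit(vmap(f)) is a fair
# single-device baseline). Warm up both so compile time is excluded.
pmap_f = pmap(f)
vmap_f = jit(vmap(f))
x = np.array([0., 1.])
pmap_f(x).block_until_ready()
vmap_f(x).block_until_ready()

t0 = time.time(); pmap_f(x).block_until_ready()
print('pmap     :', time.time() - t0)
t0 = time.time(); vmap_f(x).block_until_ready()
print('jit(vmap):', time.time() - t0)
```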
cc @neerajprad
By the way, AIUI none of the XLA collective operations (for cross-device communication, like all-reduces) are implemented on XLA:CPU, so a lot of pmap features (including nested pmaps) won’t work on XLA:CPU yet. Several are being added to XLA:GPU now though, so there’s a lot of forward progress.
If you end up wanting a collective primitive on XLA:CPU, just open an issue and we’ll work with the XLA folks on it!
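To make the limitation concrete (a hypothetical sketch of mine, not code from the thread): any `pmap` whose body uses a collective such as `lax.psum` would have been affected on XLA:CPU at the time, while working across multiple GPUs or TPUs.

```python
import jax
import jax.numpy as np
from jax import lax, pmap

# Hypothetical example of a collective: lax.psum all-reduces x across
# the devices participating in the pmap (axis 'i'). Per the comment
# above, collectives like this were not implemented on XLA:CPU then.
def normalize(x):
    return x / lax.psum(x, axis_name='i')

out = pmap(normalize, axis_name='i')(
    np.arange(jax.local_device_count(), dtype=np.float32))
```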
Thanks @mattjj! I can make it fast now with:
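(The snippet that originally followed was not captured in this mirror. A plausible reconstruction, assuming the speedup came from rolling the sampling loop into `lax.scan` so each device runs one fused XLA computation instead of `n` separate dispatches, is:)

```python
import jax.numpy as np
from jax import lax, pmap, random

n = 10

# Hypothetical reconstruction (the actual snippet is missing from this
# mirror): fold the sampling loop into lax.scan so that pmap traces one
# fused XLA computation per device instead of dispatching g() n times.
def f_scan(x):
    def step(carry, _):
        rng, x = carry
        key, subkey = random.split(rng)
        x = x + 0.01 * random.normal(subkey)
        return (key, x), x

    _, collection = lax.scan(step, (random.PRNGKey(0), x), None, length=n)
    return collection

pmap(f_scan)(np.array([0., 1.]))
```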
Currently, it works well for some toy models, so I will close this issue. If I find a performance issue with larger models, I'll open a separate issue. Thanks again!