pmap is pretty slow for functions involving jax.random

The following repro script reproduces the slowness described in the title. It happens on CPU and might happen on GPU too (but I don't have 2 GPUs to test).

import jax.numpy as np

from jax import pmap, jit, random
# requires the flag: XLA_FLAGS=--xla_force_host_platform_device_count=12
from jax.config import config; config.update('jax_platform_name', 'cpu')

n = 10

@jit
def g(rng, x):
    key, subkey = random.split(rng)
    return key, x + 0.01 * random.normal(subkey)

def f(x):
    rng = random.PRNGKey(0)
    collection = []
    for i in range(n):
        rng, x = g(rng, x)
        collection.append(x)
    return np.stack(collection)

pmap(f)(np.array([0., 1.]))
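
To see where the time goes, it can help to time the first call (which includes tracing and compilation) separately from a later call. The following is only a sketch built on the repro above: the `timed` helper is a hypothetical name, the device count of 2 is an assumption chosen to match the two-element input, and the platform is left at whatever JAX selects by default.

```python
import os

# Assumption: two virtual CPU devices are enough for a two-element input.
# XLA_FLAGS has to be set before JAX initializes its backends.
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=2")

import time

import jax
import jax.numpy as jnp
from jax import jit, pmap, random

n = 10


@jit
def g(rng, x):
    # Split the key, then perturb x with a small Gaussian sample.
    key, subkey = random.split(rng)
    return key, x + 0.01 * random.normal(subkey)


def f(x):
    rng = random.PRNGKey(0)
    collection = []
    for _ in range(n):
        rng, x = g(rng, x)
        collection.append(x)
    return jnp.stack(collection)


def timed(fn, *args):
    # Hypothetical helper: wall-clock time of one call, waiting for the result.
    start = time.perf_counter()
    out = fn(*args)
    out.block_until_ready()
    return out, time.perf_counter() - start


# Use as many inputs as there are local devices so pmap always has room.
num_devices = jax.local_device_count()
xs = jnp.arange(num_devices, dtype=jnp.float32)

pf = pmap(f)
out, first_call = timed(pf, xs)   # includes tracing and compilation
_, second_call = timed(pf, xs)    # reuses the compiled executable
print(f"devices={num_devices} first={first_call:.3f}s second={second_call:.3f}s")
```

If the second call is much faster than the first, most of the cost is compilation rather than the random number generation itself.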

cc @neerajprad

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Comments: 8 (8 by maintainers)

Top GitHub Comments

1 reaction
mattjj commented, Jun 7, 2019

By the way, AIUI none of the XLA collective operations (for cross-device communication, like all-reduces) are implemented on XLA:CPU, so a lot of pmap features (including nested pmaps) won’t work on XLA:CPU yet. Several are being added to XLA:GPU now though, so there’s a lot of forward progress.

If you end up wanting a collective primitive on XLA:CPU, just open an issue and we’ll work with the XLA folks on it!
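
For readers unfamiliar with the term, a collective inside pmap looks roughly like the sketch below, which uses `lax.psum` as an all-reduce over a named axis. The axis name `"i"` is an arbitrary choice for illustration. Whether this runs on CPU depends on the JAX/XLA version in use; at the time of the comment above it was reported as not implemented on XLA:CPU.

```python
import jax
import jax.numpy as jnp
from jax import lax, pmap

# One input element per local device.
num_devices = jax.local_device_count()
xs = jnp.arange(num_devices, dtype=jnp.float32)

# psum is an all-reduce: every device receives the sum over the mapped axis.
all_sum = pmap(lambda x: lax.psum(x, axis_name="i"), axis_name="i")
totals = all_sum(xs)
print(totals)
```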

1 reaction
fehiepsi commented, Jun 7, 2019

Thanks @mattjj! I can make it fast now with

XLA_FLAGS="--xla_force_host_platform_device_count=12 --xla_cpu_multi_thread_eigen=False --intra_op_parallelism_threads=1"

Currently, it works well for some toy models, so I will close this issue. If I find performance issues with larger models, I'll open a separate issue. Thanks again!
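
The same flags can also be set from inside a Python script instead of the shell, as long as that happens before JAX is imported. This is a sketch only: the flag strings are copied verbatim from the comment above, and whether each one is still accepted depends on the installed JAX/XLA version.

```python
import os

# Flags copied verbatim from the comment above; not verified against any
# particular JAX/XLA release.
flags = [
    "--xla_force_host_platform_device_count=12",
    "--xla_cpu_multi_thread_eigen=False",
    "--intra_op_parallelism_threads=1",
]

# Set the variable before JAX is imported, so the backend sees it when it
# initializes.
os.environ["XLA_FLAGS"] = " ".join(flags)

# import jax  # import only after the environment variable is in place
print(os.environ["XLA_FLAGS"])
```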
