pmap seems to drastically improve performance in the example notebook

Description

I am running blackjax in WSL2 on a 32-core CPU.

I was playing with the example notebook (https://blackjax-devs.github.io/blackjax/examples/Introduction.html#), and noticed that the CPU utilization is actually quite low when running multiple chains.

I modified the code by first exposing 32 host (CPU) devices to JAX:

import numpyro as npr
npr.util.set_host_device_count(32)
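
For context, here is a minimal sketch of the mechanism behind that call, assuming the numpyro helper works by setting the XLA flag `--xla_force_host_platform_device_count` before JAX initializes its backend. The function name `with_host_device_count` is hypothetical and only illustrates the idea; it is not part of numpyro or JAX.

```python
import os
import re


def with_host_device_count(existing_flags: str, n: int) -> str:
    """Return an XLA_FLAGS string that asks XLA to expose n host (CPU) devices.

    Hypothetical helper for illustration; any previous device-count flag is
    dropped so that the new value is the only one present.
    """
    kept = re.sub(
        r"--xla_force_host_platform_device_count=\S+", "", existing_flags
    ).split()
    return " ".join([f"--xla_force_host_platform_device_count={n}"] + kept)


# This must run before JAX initializes its backend (in practice, before the
# first JAX computation in the process), otherwise the flag has no effect.
os.environ["XLA_FLAGS"] = with_host_device_count(os.environ.get("XLA_FLAGS", ""), 32)
```

After this, `jax.local_device_count()` should report the requested number of CPU devices, which is what lets `pmap` place one chain on each of them.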

Then I re-used the inference loop from the single-chain example, but instead of using vmap I used pmap to parallelize the execution:

rng_key = jax.random.PRNGKey(0)

keys = jax.random.split(rng_key, num_chains)
inference_loop = jax.pmap(
    inference_loop, in_axes=(0, 0, None, None), static_broadcasted_argnums=(2, 3)
)

states = inference_loop(keys, initial_states, nuts.step, 1_000)
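
The difference between the two transformations can be sketched on a toy problem. This is an illustration under stated assumptions, not the blackjax example itself: `toy_chain` is a hypothetical stand-in for the inference loop (a Gaussian random walk driven by `jax.lax.scan`), and the number of chains is set to the number of local devices because `pmap` cannot map over more entries than there are devices. The number of steps is bound with `functools.partial` instead of `static_broadcasted_argnums` to keep the sketch short.

```python
from functools import partial

import jax
import jax.numpy as jnp


def toy_chain(key, x0, num_steps):
    """Hypothetical stand-in for an inference loop: a Gaussian random walk."""

    def step(x, k):
        x = x + 0.1 * jax.random.normal(k)
        return x, x

    keys = jax.random.split(key, num_steps)
    _, xs = jax.lax.scan(step, x0, keys)
    return xs


# pmap needs at least one device per chain, so tie the chain count to the
# number of devices that JAX can see on this host.
num_chains = jax.local_device_count()
keys = jax.random.split(jax.random.PRNGKey(0), num_chains)
x0 = jnp.zeros(num_chains)

run = partial(toy_chain, num_steps=100)

# vmap: all chains are batched into a single program on one device.
xs_vmap = jax.vmap(run)(keys, x0)

# pmap: one chain per device, executed in parallel across devices.
xs_pmap = jax.pmap(run)(keys, x0)
```

Both calls return an array of shape `(num_chains, 100)`. With `vmap` the chains advance together inside one batched computation, whereas with `pmap` each chain runs as its own program on its own device.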

This seems to cut the running time from about 2 minutes to about 3 seconds:

# vmap
Wall time: 2min 10s
# pmap
Wall time: 2.91 s
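
When comparing wall times like these, it helps to keep two JAX behaviors in mind: the first call of a transformed function includes compilation, and results are computed asynchronously. The sketch below shows one way to time a call with that in mind; the helper name `timed` and the toy function are assumptions made for illustration and do not come from the notebook.

```python
import time

import jax
import jax.numpy as jnp


def timed(fn, *args):
    """Hypothetical helper: run fn and wait for its (single array) result."""
    start = time.perf_counter()
    out = fn(*args)
    out.block_until_ready()  # wait for the asynchronous computation to finish
    return out, time.perf_counter() - start


f = jax.jit(lambda x: jnp.sin(x).sum())
x = jnp.arange(1000.0)

_, first_call = timed(f, x)   # includes tracing and compilation
_, second_call = timed(f, x)  # reuses the compiled program
print(f"first call: {first_call:.4f} s, second call: {second_call:.4f} s")
```

Timing the second call separately makes it easier to tell whether a difference comes from compilation or from the execution itself.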

Am I doing something wrong here, or should the example actually be adjusted to use pmap?

Reproducing

See full notebooks here:
https://gist.github.com/elanmart/810f1964738b0ddd8f108b17b7969f82

Setup

Python implementation: CPython
Python version       : 3.9.12
IPython version      : 8.4.0

jax     : 0.3.14
jaxlib  : 0.3.14
blackjax: 0.8.2

Compiler    : GCC 7.5.0
OS          : Linux
Release     : 5.10.102.1-microsoft-standard-WSL2
Machine     : x86_64
Processor   : x86_64
CPU cores   : 32
Architecture: 64bit

Issue Analytics

  • State: closed
  • Created a year ago
  • Comments: 6 (6 by maintainers)

Top GitHub Comments

2 reactions
elanmart commented, Jul 11, 2022

Thanks a lot for the explanation! I’ll be happy to open a PR adding a small section on pmap.

0 reactions
rlouf commented, Aug 29, 2022

No problem! Thank you for letting us know.

