Limit jax multithreading

By default jax appears to multithread most operations, e.g.

import jax.random as jr
jrkey = jr.PRNGKey(0)
x = jr.normal(jrkey, shape=(50000, 50000))  # ~10 GB as float32
x @ x

will run across all available cores. This is great in general, and matches numpy’s behavior. But it causes problems when running many small jobs in parallel, e.g., running the same script with 4 different random seeds on a 4-core machine: each of the 4 processes tries to use every core, so their threads compete for the same CPUs.

Is there any option in jax to cap the number of threads that it uses? Something like https://stackoverflow.com/questions/17053671/python-how-do-you-stop-numpy-from-multithreading?
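For comparison, the approach commonly given for numpy (as in the Stack Overflow question above) caps the BLAS/OpenMP thread pools through environment variables that must be set before numpy is first imported. A minimal sketch, assuming numpy is backed by OpenBLAS, MKL, or an OpenMP-based BLAS (which variable actually takes effect depends on the build); `limit_blas_threads` is a hypothetical helper name, not a numpy API:

```python
import os

# Environment variables read by common BLAS/OpenMP backends when they load.
# Which one matters depends on how numpy was built (OpenBLAS, MKL, ...).
BLAS_THREAD_VARS = (
    "OMP_NUM_THREADS",
    "OPENBLAS_NUM_THREADS",
    "MKL_NUM_THREADS",
)


def limit_blas_threads(n: int) -> None:
    """Hypothetical helper: cap BLAS thread pools at n threads.

    Must be called before numpy is imported for the first time, because the
    BLAS library reads these variables when it is loaded.
    """
    if n < 1:
        raise ValueError("n must be at least 1")
    for var in BLAS_THREAD_VARS:
        os.environ[var] = str(n)


if __name__ == "__main__":
    limit_blas_threads(1)
    # import numpy as np  # import only after the limit is in place
```

jax’s CPU backend is XLA rather than numpy’s BLAS, so these variables may not cap it; the comments below use XLA_FLAGS instead.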

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Reactions: 1
  • Comments: 8 (7 by maintainers)

Top GitHub Comments

14 reactions
samuela commented, Sep 30, 2019

For my own (and others’) future googling, my current approach looks like this:

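from multiprocessing import get_context
import os

# Limit ourselves to single-threaded jax/xla operations to avoid thrashing. See
# https://github.com/google/jax/issues/743.
# Set this before importing jax so XLA picks it up when it initializes.
os.environ["XLA_FLAGS"] = ("--xla_cpu_multi_thread_eigen=false "
                           "intra_op_parallelism_threads=1")

def job(random_seed: int):
  # jax jax jax (placeholder: the per-seed jax work goes here)
  pass

if __name__ == "__main__":
  # See https://codewithoutrules.com/2018/09/04/python-multiprocessing/.
  with get_context("spawn").Pool() as pool:
    # imap_unordered is lazy; consume it so every job finishes before the
    # pool is terminated on leaving the `with` block.
    for _ in pool.imap_unordered(job, range(100)):
      pass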

There may be a way better way, but it seems to work 🤷‍♀️
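One caveat with the snippet above: assigning `os.environ["XLA_FLAGS"]` replaces any XLA flags already set in the environment. A small sketch that appends instead, assuming flags are whitespace-separated as in the string above; `add_xla_flags` is a hypothetical helper, not part of jax. Like the original, it has to run before jax is imported.

```python
import os


def add_xla_flags(*flags: str) -> str:
    """Hypothetical helper: append flags to XLA_FLAGS without clobbering it.

    Only exact duplicates are skipped; a conflicting setting of the same
    flag that is already present is left in place.
    Returns the resulting XLA_FLAGS value.
    """
    existing = os.environ.get("XLA_FLAGS", "").split()
    for flag in flags:
        if flag not in existing:
            existing.append(flag)
    value = " ".join(existing)
    os.environ["XLA_FLAGS"] = value
    return value


if __name__ == "__main__":
    # Call before `import jax`, e.g. at the top of the worker module.
    print(add_xla_flags("--xla_cpu_multi_thread_eigen=false"))
```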

2 reactions
mattjj commented, Jul 16, 2019

Did you try setting the XLA_FLAGS environment variable like this? (My earlier comment didn’t explain this very well.)

XLA_FLAGS="--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1" python my_file.py

That seems to work for me in a test, at least for a big matmul.
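The command above sets XLA_FLAGS for one invocation. To reproduce the scenario from the question (the same script run with several seeds, one process per core), a launcher can set the variable per child process. A minimal sketch: `my_file.py` comes from the command above, but the `--seed` argument is a hypothetical CLI option of your own script, and `build_jobs`/`run_jobs` are illustrative names. It defaults to only the `--xla_cpu_multi_thread_eigen=false` flag; pass the full string from the command if you want both.

```python
import os
import subprocess
import sys
from typing import Dict, List, Tuple

Job = Tuple[List[str], Dict[str, str]]


def build_jobs(script: str, seeds: List[int],
               xla_flags: str = "--xla_cpu_multi_thread_eigen=false") -> List[Job]:
    """Return one (argv, env) pair per seed; nothing is started here."""
    jobs = []
    for seed in seeds:
        env = dict(os.environ)  # copy, so the parent's environment is untouched
        env["XLA_FLAGS"] = xla_flags
        # "--seed" is a hypothetical argument of the user's script.
        argv = [sys.executable, script, "--seed", str(seed)]
        jobs.append((argv, env))
    return jobs


def run_jobs(jobs: List[Job]) -> List[int]:
    """Start every job at once, then wait for all of them; return exit codes."""
    procs = [subprocess.Popen(argv, env=env) for argv, env in jobs]
    return [p.wait() for p in procs]


if __name__ == "__main__":
    # e.g. 4 seeds on a 4-core machine: one single-threaded process per core.
    print(run_jobs(build_jobs("my_file.py", seeds=[0, 1, 2, 3])))
```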
