
JAX slow compared to numba for certain calculations

See original GitHub issue

Hi there,

I recently coded up a version of the Glicko rating system and found that for the particular calculations involved, using jit decorators in JAX runs slowly, while the same strategy works well in numba. I put together a benchmark and notes here:

https://github.com/martiningram/jax_vs_numba_glicko

Naively using jit decorators runs the benchmark in 20s in JAX, compared to 121ms in numba, so clearly that’s not a good strategy with JAX. @mattjj asked me to share this here: while this is not a workload you have been focusing on with JAX, it could be good to understand why it is slow and perhaps how to improve things. I hope this is useful!
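For readers skimming the linked benchmark, here is a minimal sketch of the pattern involved. Everything below is an illustration, not code from the repository: `rating_update` is a made-up stand-in for the real Glicko calculations. The point it shows is that driving a jitted scalar function from a Python loop pays JAX's per-call dispatch overhead on every iteration, whereas moving the loop inside a single jitted function via `lax.scan` compiles and dispatches the whole sequence once.

```python
import jax
import jax.numpy as jnp

@jax.jit
def rating_update(mu, phi, outcome):
    # Toy scalar update standing in for the real Glicko equations (hypothetical).
    g = 1.0 / jnp.sqrt(1.0 + 3.0 * phi ** 2 / jnp.pi ** 2)
    e = 1.0 / (1.0 + jnp.exp(-g * mu))
    return mu + phi ** 2 * g * (outcome - e), phi * 0.99

def python_loop(mu, phi, outcomes):
    # Every iteration re-enters Python and pays JAX's per-call dispatch cost,
    # which dwarfs the nanosecond-scale arithmetic inside.
    for o in outcomes:
        mu, phi = rating_update(mu, phi, o)
    return mu, phi

@jax.jit
def fused_loop(mu, phi, outcomes):
    # lax.scan keeps the whole sequence inside one compiled XLA program,
    # so dispatch overhead is paid once rather than per element.
    def step(carry, o):
        mu, phi = carry
        return rating_update(mu, phi, o), None
    (mu, phi), _ = jax.lax.scan(step, (mu, phi), outcomes)
    return mu, phi

outcomes = jnp.array([1.0, 0.0, 1.0, 1.0, 0.0])
print(fused_loop(0.0, 1.0, outcomes))
```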

Issue Analytics

  • State: closed
  • Created: 4 years ago
  • Reactions: 2
  • Comments: 5 (1 by maintainers)

Top GitHub Comments

14 reactions
jekbradbury commented, Jun 7, 2019

Somewhat off topic for the JAX issue tracker, but this is a pretty ideal case for the Julia compiler, which can run calculate_approximate_likelihood in about 400ns, or a little less time than it takes to generate_random_observations: https://github.com/martiningram/jax_vs_numba_glicko/pull/1. Since all three of XLA/Numba/Julia are using LLVM, all this really means is that JAX/XLA has even more room to grow on heavily-scalar applications like this one; I don’t see any reason why we wouldn’t (eventually, in principle) be able to generate ~the same code Julia does here.

0 reactions
jakevdp commented, Jun 21, 2022

Closing, as I don’t think there is any action item here. Ongoing work on dynamic shapes may improve the situation with respect to shape-agnostic JIT compilation, but that’s tracked elsewhere.
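As a rough illustration of the shape issue mentioned here (an assumption about the workload, not something stated in the issue): `jit` traces and compiles once per distinct input shape, so variable-length inputs trigger repeated recompilation, and a common workaround is to pad to a fixed length so one executable serves every input.

```python
import jax
import jax.numpy as jnp

@jax.jit
def total(x):
    return jnp.sum(x)

total(jnp.ones(3))   # traces and compiles for shape (3,)
total(jnp.ones(5))   # a new shape, so it traces and compiles again
total(jnp.ones(5))   # cache hit: reuses the (5,) executable

# A common workaround: pad variable-length inputs to one fixed length so a
# single compiled executable serves every length up to the maximum.
def padded_total(x, max_len=8):
    padded = jnp.zeros(max_len).at[: x.shape[0]].set(x)
    return total(padded)

print(padded_total(jnp.ones(3)), padded_total(jnp.ones(5)))
```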


Top Results From Across the Web

JAX(XLA) vs Numba(LLVM) Reduction - python - Stack Overflow
You’ll see that for these microbenchmarks, JAX is a few milliseconds slower than both numpy and numba. So does this mean JAX is... (see the benchmark-timing sketch after these results)

Performance Tips — Numba 0.50.1 documentation
A reasonably effective approach to achieving high performance code is to profile the code running with real data and use that to guide...

Faster Python calculations with Numba | Hacker News
JAX is actually lower level than deep learning (despite including some specialized constructs) which makes it an almost drop-in replacement for numpy that...

Hard to beat Numba / JAX loop for generating Mandelbrot
I thought fma was for accuracy, and muladd was for performance: fma(x, y, z) Computes x*y+z without rounding the intermediate result x*y. On...

Why You Should (or Shouldn’t) be Using Google’s JAX in 2022
JAX performs this calculation in only 5.54 ms - over 86 times faster than NumPy. ... that’s 8,600%. JAX has the potential to...
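The Stack Overflow result above compares microbenchmark timings. As a general JAX benchmarking aside (not something taken from that thread): dispatch is asynchronous and the first call to a jitted function includes compilation, so a fair timing blocks on the result and excludes the warm-up call.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sum(x * x)

x = jnp.ones(1_000_000)
f(x).block_until_ready()          # warm-up: includes tracing + compilation

start = time.perf_counter()
for _ in range(100):
    f(x).block_until_ready()      # block so asynchronous dispatch is counted
elapsed = (time.perf_counter() - start) / 100
print(f"{elapsed * 1e6:.1f} µs per call")
```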
