JAX slow compared to numba for certain calculations
Hi there,
I recently coded up a version of the Glicko rating system and found that, for the particular calculations involved, using `jit` decorators in JAX runs slowly, while the same strategy works well in numba. I put together a benchmark and notes here: https://github.com/martiningram/jax_vs_numba_glicko

Naively using `jit` decorators runs the benchmark in 20 s in JAX, compared to 121 ms in numba, so clearly that's not a good strategy with JAX. @mattjj asked me to share this here: while this is not a particular workload you have been looking at with JAX, it could be good to understand why it is slow and perhaps how to improve things. I hope this is useful!
Issue Analytics
- State: closed
- Created: 4 years ago
- Reactions: 2
- Comments: 5 (1 by maintainers)
Somewhat off topic for the JAX issue tracker, but this is a pretty ideal case for the Julia compiler, which can run `calculate_approximate_likelihood` in about 400 ns, or a little less time than it takes to `generate_random_observations`: https://github.com/martiningram/jax_vs_numba_glicko/pull/1. Since all three of XLA/Numba/Julia are using LLVM, all this really means is that JAX/XLA has even more room to grow on heavily-scalar applications like this one; I don't see any reason why we wouldn't (eventually, in principle) be able to generate ~the same code Julia does here.

Closing, as I don't think there is any action item here. Ongoing work on dynamic shapes may improve the situation with respect to shape-agnostic JIT compilation, but that's tracked elsewhere.
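For anyone landing here with a similar workload, a common workaround (a sketch only, using a made-up toy update rather than the repo's code, and assuming the observation sequence has a fixed length) is to move the Python-level loop into the compiled computation with `jax.lax.scan`, so dispatch overhead is paid once per sequence rather than once per observation:

```python
import jax
import jax.numpy as jnp

def rating_update(mu, outcome):
    # Toy Elo/Glicko-style scalar update (hypothetical stand-in,
    # not the benchmark's actual calculation).
    expected = 1.0 / (1.0 + jnp.exp(-mu / 400.0))
    return mu + 32.0 * (outcome - expected)

@jax.jit
def run_all(mu0, outcomes):
    # lax.scan keeps the whole loop inside one compiled XLA computation,
    # so JAX's Python-level dispatch overhead is paid only once.
    def step(mu, outcome):
        new_mu = rating_update(mu, outcome)
        return new_mu, new_mu  # (carry, per-step output)

    final_mu, history = jax.lax.scan(step, mu0, outcomes)
    return final_mu, history

outcomes = jnp.array([1.0, 0.0, 1.0], dtype=jnp.float32)
final_mu, history = run_all(jnp.float32(0.0), outcomes)
```

This removes the per-call overhead but does not change the scalar code XLA generates, so it is only a partial answer to the codegen gap discussed in the comment above.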