[question] how to reduce overhead of using jax.numpy instead of numpy
Hi JAX team,

I don't know how to properly ask questions, so I'm opening an issue for this. In this gist, I compare the performance of `jax.numpy` vs `numpy` on CPU with the `sum_logistic` function (which is used in JAX's quickstart guide). The gist shows that `jax.numpy` is much slower than `numpy` on such a small task, but it scales well as N grows from 10 to 1000. So I suspect I'm using JAX the wrong way.

I'm new to JAX and eager to learn. Could someone please let me know what else I need to do to use `jax.numpy` properly and improve its speed?
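For reference, the comparison is roughly of this shape (a minimal sketch, not the exact gist code; the timing harness and array sizes here are my own, but `sum_logistic` matches the quickstart definition):

```python
import time
import numpy as np
import jax.numpy as jnp

def sum_logistic_np(x):
    return np.sum(1.0 / (1.0 + np.exp(-x)))

def sum_logistic_jnp(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

for n in (10, 1000):
    x_np = np.arange(n, dtype=np.float32)
    x_jnp = jnp.arange(n, dtype=jnp.float32)

    t0 = time.perf_counter()
    for _ in range(1000):
        sum_logistic_np(x_np)
    t_np = time.perf_counter() - t0

    t0 = time.perf_counter()
    for _ in range(1000):
        sum_logistic_jnp(x_jnp).block_until_ready()  # wait for async dispatch
    t_jnp = time.perf_counter() - t0

    print(f"N={n}: numpy {t_np:.4f}s vs jax.numpy {t_jnp:.4f}s")
```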
Thanks for bringing this up! See this comment for an answer very similar to the one below.
This kind of behavior is expected: numpy has really low operation dispatch overheads, since it’s had years of expert engineering. JAX’s operation dispatch overheads are high, not for any design reason but because it hasn’t yet been engineered to the same point.
Dispatch costs are independent of the size of the arrays being operated on. So if you're dispatching a lot of small operations, JAX's dispatch overhead is going to loom large. But if you scale up the data being operated on, those fixed dispatch costs are amortized over much more actual computation.
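Here's an illustrative sketch of that (my own, not from the gist): timing a single eager `jax.numpy` op at several sizes shows the per-call time staying roughly flat until the arrays are large enough for the actual FLOPs to dominate the fixed dispatch cost:

```python
import time
import jax.numpy as jnp

def per_call_seconds(x, n_calls=1000):
    (x + x).block_until_ready()  # warm up so first-call compilation isn't timed
    t0 = time.perf_counter()
    for _ in range(n_calls):
        (x + x).block_until_ready()  # block: JAX dispatches asynchronously
    return (time.perf_counter() - t0) / n_calls

for n in (10, 1_000, 1_000_000):
    x = jnp.ones(n, dtype=jnp.float32)
    print(f"N={n:>9,}: {per_call_seconds(x) * 1e6:8.1f} us per add")
# At small N the time is almost entirely fixed dispatch overhead, so it
# barely changes with N; only at large N do the FLOPs start to dominate.
```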
Another way to crush dispatch overheads is to use `@jit`. When you call an `@jit` function, you only pay the dispatch cost once, no matter how many `jax.numpy` functions are called inside that `@jit` function. Plus, XLA will optimize the whole function end-to-end, including fusing operations together and optimizing memory layouts and use.

The example function in your gist is pretty small, in the sense that it doesn't call many `jax.numpy` functions, so these dispatch overheads were still large relative to the total time of the computation, and you had to scale up the array sizes to see XLA's compilation benefits rather than just measuring dispatch overheads. But the ideal case for JAX is when you use `@jit` on functions that have many more FLOPs in them. If you `@jit` decently sized functions, the dispatch overhead won't be noticeable.

To put a finer point on it: if you dumped the assembly code that XLA is generating for your use case, you'd see that it's really highly optimized. Your program isn't spending much time executing that code; instead, it's spending time in dispatch overheads.
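For example, here's what that looks like with the quickstart's `sum_logistic` (a sketch; the array size and timing loop are arbitrary):

```python
import timeit
import jax
import jax.numpy as jnp

@jax.jit
def sum_logistic(x):
    # Several jax.numpy ops, but under @jit they are compiled into a single
    # fused XLA computation and dispatched with one call from Python.
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x = jnp.arange(1000, dtype=jnp.float32)
sum_logistic(x).block_until_ready()  # first call pays tracing + compilation

# Subsequent calls pay the dispatch cost once per call, not once per op.
print(timeit.timeit(lambda: sum_logistic(x).block_until_ready(), number=1000))
```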
So the main dispatch-overhead-fighting weapon at your disposal right now is `@jit` on bigger functions. We also expect dispatch overheads to come down over time through a few avenues.

What do you think?
Thanks @mattjj!