
[question] how to reduce overhead of using jax.numpy instead of numpy

See original GitHub issue

Hi JAX team,

I don’t know the right place to ask questions, so I’m opening an issue for this. In this gist, I compare the performance of jax.numpy vs numpy on CPU with the sum_logistic function (which is used in JAX’s quickstart guide). The gist shows that jax.numpy is much slower than numpy for such a small task, though it scales well when N is increased from 10 to 1000. So I suspect I’m using JAX the wrong way.

I’m new to JAX and eager to learn. Could someone please let me know what else I need to do to use jax.numpy properly and improve its speed?
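For reference, here is a minimal sketch of the kind of comparison the gist describes (the gist itself is not reproduced here); the sum_logistic definitions follow the JAX quickstart, and the sizes 10 and 1000 match the N values mentioned above:

```python
# A sketch, not the original gist: time sum_logistic with numpy vs jax.numpy.
import time

import numpy as np
import jax.numpy as jnp


def sum_logistic_np(x):
    return np.sum(1.0 / (1.0 + np.exp(-x)))


def sum_logistic_jnp(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))


for n in (10, 1000):
    x_np = np.arange(n, dtype=np.float32)
    x_jnp = jnp.arange(n, dtype=jnp.float32)

    t0 = time.perf_counter()
    sum_logistic_np(x_np)
    t1 = time.perf_counter()
    # block_until_ready() makes sure we time the computation, not just dispatch.
    sum_logistic_jnp(x_jnp).block_until_ready()
    t2 = time.perf_counter()

    print(f"N={n}: numpy {t1 - t0:.6f}s, jax.numpy {t2 - t1:.6f}s")
```

A single run like this is noisy; timeit-style repeats give steadier numbers, but it is the kind of small-array comparison being asked about.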

Issue Analytics

  • State: closed
  • Created: 5 years ago
  • Reactions: 3
  • Comments: 7 (5 by maintainers)

Top GitHub Comments

8 reactions
mattjj commented, Feb 21, 2019

Thanks for bringing this up! See this comment for an answer very similar to the one below.

This kind of behavior is expected: numpy has really low operation dispatch overheads, since it’s had years of expert engineering. JAX’s operation dispatch overheads are high, not for any design reason but because it hasn’t yet been engineered to the same point.

Dispatch costs are independent of the size of the arrays being operated on. So if you’re dispatching a lot of small operations, JAX’s dispatch overhead is going to loom large. But if you scale up the data being operated on, those fixed dispatch costs aren’t as big.
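To make the fixed-cost point concrete, here is a rough micro-benchmark sketch (mine, not from the thread): with a single tiny operation there is almost no arithmetic to do, so the per-call time is essentially pure dispatch overhead for each library.

```python
import time

import numpy as np
import jax.numpy as jnp

# Warm up so the jax.numpy loop below isn't dominated by first-call setup.
jnp.add(1.0, 1.0).block_until_ready()

t0 = time.perf_counter()
for _ in range(1000):
    np.add(1.0, 1.0)
t1 = time.perf_counter()
for _ in range(1000):
    jnp.add(1.0, 1.0).block_until_ready()
t2 = time.perf_counter()

# With scalar operands, the per-call time is dominated by dispatch, not math.
print(f"numpy:     {(t1 - t0) / 1000 * 1e6:.2f} µs per op")
print(f"jax.numpy: {(t2 - t1) / 1000 * 1e6:.2f} µs per op")
```

jnp.add here is just a stand-in for any small element-wise call; the point is that its per-call cost does not depend on the size of the operands.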

Another way to crush dispatch overheads is to use @jit. When you call an @jit function, you only pay the dispatch cost once, no matter how many jax.numpy functions are called inside that @jit function. Plus, XLA will end-to-end optimize the whole function, including fusing operations together and optimizing memory layouts and use.
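As a sketch of what that looks like (mine, reusing the quickstart-style sum_logistic), the whole function is compiled into a single XLA computation, so the dispatch cost is paid once per call rather than once per jax.numpy operation:

```python
import time

import jax
import jax.numpy as jnp


def sum_logistic(x):
    # Several jax.numpy ops; without jit, each one pays its own dispatch cost.
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))


sum_logistic_jit = jax.jit(sum_logistic)

x = jnp.arange(1000, dtype=jnp.float32)

# The first jitted call traces and compiles the function, so warm it up first.
sum_logistic_jit(x).block_until_ready()

t0 = time.perf_counter()
sum_logistic(x).block_until_ready()       # op-by-op dispatch
t1 = time.perf_counter()
sum_logistic_jit(x).block_until_ready()   # one fused, precompiled computation
t2 = time.perf_counter()

print(f"un-jitted: {t1 - t0:.6f}s, jitted: {t2 - t1:.6f}s")
```

Besides paying dispatch once, the jitted version also gets the end-to-end XLA optimization described above: the exponential, division, and sum can be fused into a single computation.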

The example function in your gist is pretty small in the sense that it didn’t call many jax.numpy functions, so these dispatch overheads were still large relative to the total time of the computation, and you had to scale up the array sizes to see XLA’s compilation benefits rather than just measuring dispatch overheads. But the ideal case for JAX is when you use @jit on functions that have many more FLOPs in them. If you @jit decently sized functions, the dispatch overhead won’t be noticeable.

To put a finer point on it: if you dumped the assembly code that XLA is generating for your use case, you’d see that it’s really highly optimized. Your program isn’t spending much time executing that code. Instead, it’s spending time in dispatching overheads.

So the main dispatch-overhead-fighting weapon at your disposal now is @jit on bigger functions. We expect dispatch overheads to come down over time, through a few avenues:

  1. on the CPU backend, we currently copy back and forth between JAX’s DeviceArrays and numpy ndarrays, but I think we can share memory between the two and avoid all those copies;
  2. on the GPU backend, we’re using a relatively slow memory allocation strategy (see #417);
  3. we can probably optimize our XLA client in miscellaneous other ways.

What do you think?

0 reactions
twiecki commented, Apr 17, 2019

Thanks @mattjj!

