Do JAX jit'd Python loops run faster than jit'd LAX loop constructs?

Please see the gist (minimal repro), where I tried to implement a recurrent computation (an echo state network) both as a naive Python for loop and with lax.fori_loop, each under jax.jit. Using lax.fori_loop resulted in a roughly 2.6x slowdown (0.0418 s vs. 0.0160 s) compared with the jit’d naive Python for loop.

JAX speed
Params seed 100001
JAX run 0.0160 sec

LAX speed
LAX run 0.0418 sec

This is not blocking me, but I was surprised by it and I cannot find anything I did wrong, though I may have misused the APIs in some way. The JAX versions are listed as a comment in the gist.
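The gist itself is not reproduced here. As a rough, hypothetical sketch of the two styles being compared, an echo state network update can be written once as a Python loop and once with lax.fori_loop. The sizes, weight scaling, and the names make_params, run_python_loop and run_fori_loop are assumptions for illustration, not the original code.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical sizes; not taken from the original gist.
N_RES, N_IN, N_STEPS = 64, 3, 100

def make_params(seed=0):
    # Hypothetical reservoir and input weights; 0.1 scaling keeps the reservoir stable.
    k1, k2 = jax.random.split(jax.random.PRNGKey(seed))
    w_res = 0.1 * jax.random.normal(k1, (N_RES, N_RES))
    w_in = jax.random.normal(k2, (N_RES, N_IN))
    return w_res, w_in

@jax.jit
def run_python_loop(w_res, w_in, inputs):
    # The Python loop runs at trace time, so XLA receives N_STEPS unrolled updates.
    h = jnp.zeros(N_RES)
    for t in range(N_STEPS):
        h = jnp.tanh(w_res @ h + w_in @ inputs[t])
    return h

@jax.jit
def run_fori_loop(w_res, w_in, inputs):
    # lax.fori_loop traces the body once and stages out a single loop to XLA.
    def body(t, h):
        return jnp.tanh(w_res @ h + w_in @ inputs[t])
    return lax.fori_loop(0, N_STEPS, body, jnp.zeros(N_RES))

w_res, w_in = make_params(100001)
inputs = jax.random.normal(jax.random.PRNGKey(1), (N_STEPS, N_IN))
h_py = run_python_loop(w_res, w_in, inputs)
h_lax = run_fori_loop(w_res, w_in, inputs)
print("max abs difference:", float(jnp.max(jnp.abs(h_py - h_lax))))
```

Both versions should produce the same final state up to floating-point rounding; what differs is the shape of the program handed to XLA, which is what the timing difference above is about.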

Issue Analytics

  • State: closed
  • Created 5 years ago
  • Comments: 8 (8 by maintainers)

Top GitHub Comments

2 reactions
mattjj commented, Feb 18, 2019

Thanks for bringing up this issue! Hope you don’t mind that I tweaked the title to be a bit more explicit.

My current best guess is that this points to a performance improvement XLA could make. Some thoughts below.

When you use a regular Python loop construct under JAX’s @jit, JAX (like Autograd) doesn’t even see the Python loop and instead just traces out the unrolled computation; as a consequence, the program that gets staged out to XLA is also unrolled. When you instead use a lax.while_loop (or lax.fori_loop), the JAX tracer sees that loop as a primitive, and stages out a loop construct in the XLA program.
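One way to see this difference directly is to compare the traced programs with jax.make_jaxpr. In this minimal sketch (the toy functions python_loop and fori are hypothetical), the Python loop shows up as five repeated multiply/add pairs with no loop at all, while the fori_loop version contains a single loop primitive. Depending on the JAX version, a fori_loop with static bounds may be staged out as scan rather than while.

```python
import jax
import jax.numpy as jnp
from jax import lax

def python_loop(x):
    # JAX never sees this loop; it only records the traced multiplies and adds.
    for _ in range(5):
        x = x * 2.0 + 1.0
    return x

def fori(x):
    # JAX sees fori_loop as a primitive and stages out a loop construct.
    return lax.fori_loop(0, 5, lambda i, x: x * 2.0 + 1.0, x)

x = jnp.float32(1.0)
print(jax.make_jaxpr(python_loop)(x))  # five mul/add pairs, no loop primitive
print(jax.make_jaxpr(fori)(x))         # one loop primitive (while or scan, version-dependent)
```

Both functions compute the same value; only the staged-out program differs.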

In the latter case, XLA gets strictly more information: it sees that there’s a loop construct, and it has the option to unroll it when it’s statically unrollable. (TODO for us: check that JAX is lowering the loop in a way that keeps static-unrollability transparent for XLA, since we do lift constants into the loop carry tuple.) So it should be able to generate code that is at least as good.

From XLA’s perspective, there are the usual tradeoffs with unrolling here: unrolling a loop could increase the code size (which might increase execution time if the code has to be loaded onto the device for each kernel launch), but enables some more optimizations and involves fewer branches. It may be that “branches” are expensive for the GPU backend because the loop condition might need to be pulled back to the host and checked there on each iteration, meaning more synchronizations than would be necessary in the unrolled case. Still, it seems that in principle XLA could do some unrolling or partial unrolling to mitigate that effect in cases like this one.
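The comment above is about XLA choosing to unroll on its own. Later JAX releases (after this 2019 thread) also expose a manual knob: lax.scan accepts an unroll argument that places several loop steps inside each iteration of the staged-out loop. The sketch below assumes a JAX version with that parameter; the step function is a hypothetical stand-in for a recurrent update.

```python
import jax
import jax.numpy as jnp
from jax import lax

def step(h, x):
    # Hypothetical recurrent update: the carry is h, and h is also emitted per step.
    h = jnp.tanh(0.5 * h + x)
    return h, h

@jax.jit
def rolled(xs):
    # unroll=1 (the default): one step per iteration of the XLA loop.
    return lax.scan(step, jnp.zeros(()), xs)

@jax.jit
def partially_unrolled(xs):
    # unroll=4: four steps per iteration, so roughly a quarter of the loop branches.
    return lax.scan(step, jnp.zeros(()), xs, unroll=4)

xs = jnp.linspace(-1.0, 1.0, 100)
h1, ys1 = rolled(xs)
h4, ys4 = partially_unrolled(xs)
```

The results match; the unroll factor only changes the size of the loop body XLA sees, trading code size for fewer loop iterations, which is the tradeoff described above.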

Separate from questions of execution time, one of the main reasons to use loop constructs now is to reduce compile times: unrolling big loops can mean staging out a large program to XLA, and since XLA does a lot of optimizations, that can mean a lot of redundant work.

So as a general rule of thumb, we can think of using loop constructs like lax.while_loop and lax.fori_loop as tools to reduce compile times, often by a huge amount, but that could sometimes result in reduced execution performance. In principle those reductions in execution performance could be minimal, but right now there are probably cases where XLA loops aren’t nearly as fast to execute as the unrolled code.
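Since the tradeoff is compile time versus execution time, it can help to measure the two separately for a specific workload. This is a hedged sketch, not a rigorous benchmark: the toy functions unrolled and looped and the trip count N are hypothetical. It times compilation through the ahead-of-time jax.jit(f).lower(x).compile() API, and times execution only after a warm-up call, using block_until_ready() because JAX dispatches work asynchronously.

```python
import time
import jax
import jax.numpy as jnp
from jax import lax

N = 200  # hypothetical trip count

def unrolled(x):
    # Python loop: traced into N copies of the body.
    for i in range(N):
        x = jnp.sin(x) + i
    return x

def looped(x):
    # lax.fori_loop: a single copy of the body inside an XLA loop.
    return lax.fori_loop(0, N, lambda i, x: jnp.sin(x) + i, x)

def compile_seconds(f, x):
    # Tracing + lowering + XLA compilation, without executing anything.
    start = time.perf_counter()
    jax.jit(f).lower(x).compile()
    return time.perf_counter() - start

def run_seconds(f, x, reps=20):
    g = jax.jit(f)
    g(x).block_until_ready()  # warm-up: pay compilation outside the timed region
    start = time.perf_counter()
    for _ in range(reps):
        g(x).block_until_ready()  # wait for the asynchronously dispatched result
    return (time.perf_counter() - start) / reps

x = jnp.ones(1024)
for name, f in [("unrolled", unrolled), ("fori_loop", looped)]:
    print(f"{name}: compile {compile_seconds(f, x):.3f} s, "
          f"run {run_seconds(f, x) * 1e3:.3f} ms")
```

The numbers depend heavily on the backend, the loop length, and the JAX/XLA version, so they are worth re-measuring rather than assuming either form always wins.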

0 reactions
hawkinsp commented, Feb 26, 2019

This is now fixed, but to get the fix you need either to rebuild jaxlib from source or to wait until we push new binary wheels to PyPI (probably later this week).
