Do JAX jit'd Python loops run faster than jit'd LAX loop constructs?
See original GitHub issue

Please see the gist with a minimal repro, where I tried to implement a recurrent computation (an echo state network) in both JAX and LAX styles. Using `lax.fori_loop` resulted in a roughly 3x slowdown over a jit'd naive Python for loop.
```
JAX speed
Params seed 100001
JAX run 0.0160 sec

LAX speed
LAX run 0.0418 sec
```
This is not blocking me, but I was surprised by it and I cannot find anything I did wrong, though I may have misused the APIs in some way. JAX versions are listed as comment in the gist.
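For concreteness, here is a minimal sketch of the kind of comparison described. The recurrence, shapes, and seed are illustrative stand-ins, not the original gist:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical echo-state-style recurrence: h_{t+1} = tanh(W @ h_t + x_t).
N, T = 64, 100
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(k1, (N, N)) * 0.1
xs = jax.random.normal(k2, (T, N))

@jax.jit
def python_loop(h0):
    h = h0
    for t in range(T):  # unrolled at trace time
        h = jnp.tanh(W @ h + xs[t])
    return h

@jax.jit
def lax_loop(h0):
    def body(t, h):
        return jnp.tanh(W @ h + xs[t])
    return lax.fori_loop(0, T, body, h0)  # staged as a single XLA loop

h0 = jnp.zeros(N)
a = python_loop(h0).block_until_ready()
b = lax_loop(h0).block_until_ready()
print(jnp.allclose(a, b))  # the two variants compute the same result
```

Timing each variant (after a warm-up call, with `block_until_ready()`) is how the numbers above would be measured.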
Issue Analytics
- State:
- Created 5 years ago
- Comments: 8 (8 by maintainers)
Thanks for bringing up this issue! Hope you don’t mind that I tweaked the title to be a bit more explicit.
My current best guess is that this is a potential performance improvement for XLA to make. Some thoughts below.
When you use a regular Python loop construct under JAX's `@jit`, JAX (like Autograd) doesn't even see the Python loop and instead just traces out the unrolled computation; as a consequence, the program that gets staged out to XLA is also unrolled. When you instead use a `lax.while_loop` (or `lax.fori_loop`), the JAX tracer sees that loop as a primitive, and stages out a loop construct in the XLA program.

In the latter case, XLA gets strictly more information: it sees that there's a loop construct, and it has the option to unroll it when it's statically unrollable. (TODO for us: check that JAX is lowering the loop in a way that keeps static-unrollability transparent to XLA, since we do lift constants into the loop carry tuple.) So it should be able to generate code that is at least as good.
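The difference is easy to see with `jax.make_jaxpr` (a toy function here, not the original repro):

```python
import jax
from jax import lax

def unrolled(x):
    # A Python loop is invisible to the tracer: each iteration
    # becomes a separate op in the staged program.
    for _ in range(3):
        x = x + 1.0
    return x

def looped(x):
    # fori_loop is a primitive: the whole loop is staged as one
    # structured loop construct in the XLA program.
    return lax.fori_loop(0, 3, lambda i, x: x + 1.0, x)

unrolled_jaxpr = str(jax.make_jaxpr(unrolled)(0.0))
looped_jaxpr = str(jax.make_jaxpr(looped)(0.0))
print(unrolled_jaxpr)  # three separate add equations
print(looped_jaxpr)    # a single loop primitive
```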
From XLA’s perspective, there are the usual tradeoffs with unrolling here: unrolling a loop could increase the code size (which might increase execution time if the code has to be loaded onto the device for each kernel launch), but enables some more optimizations and involves fewer branches. It may be that “branches” are expensive for the GPU backend because the loop condition might need to be pulled back to the host and checked there on each iteration, meaning more synchronizations than would be necessary in the unrolled case. Still, it seems that in principle XLA could do some unrolling or partial unrolling to mitigate that effect in cases like this one.
Separate from questions of execution time, one of the main reasons to use loop constructs now is to reduce compile times: unrolling big loops can mean staging out a large program to XLA, and since XLA does a lot of optimizations, that can mean a lot of redundant work.
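A rough way to observe the compile-time effect (the trip count and function are illustrative, and absolute times will vary by machine and backend):

```python
import time
import jax
import jax.numpy as jnp
from jax import lax

T = 1000  # a large trip count makes the unrolled program big

def unrolled(x):
    for _ in range(T):
        x = x * 0.99 + 1.0
    return x

def looped(x):
    return lax.fori_loop(0, T, lambda i, x: x * 0.99 + 1.0, x)

def first_call_time(f):
    # The first call traces and compiles; this timing is dominated
    # by compilation for a program this small.
    jitted = jax.jit(f)
    start = time.perf_counter()
    jitted(jnp.float32(0.0)).block_until_ready()
    return time.perf_counter() - start

t_unrolled = first_call_time(unrolled)
t_looped = first_call_time(looped)
print(f"unrolled: {t_unrolled:.3f}s, looped: {t_looped:.3f}s")
```

The unrolled version stages out roughly 2000 ops for XLA to optimize, while the looped version stages out a handful, so its first call is typically much faster.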
So as a general rule of thumb, we can think of loop constructs like `lax.while_loop` and `lax.fori_loop` as tools to reduce compile times, often by a huge amount, though they can sometimes reduce execution performance. In principle those reductions in execution performance could be minimal, but right now there are probably cases where XLA loops aren't nearly as fast to execute as the unrolled code.

This is now fixed, but to get the fix you either need to rebuild jaxlib from source or wait until we push new binary wheels to PyPI (probably later this week).