change `scan` and `while_loop` impls to Python versions
Currently `lax.while_loop` (and as a consequence `lax.fori_loop` and `lax.scan`) incur compilation time every time they're evaluated in op-by-op mode, making them seem slow to execute without being inside an `@jit` (since the `@jit` will handle the caching). We should remedy that, and make the op-by-op impl rules fast. (Separately, while we're looking at this code, we might be able to replace `_while_loop_translation_rule` by calling `xla.lower_fun` on its new impl.)
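For readers unfamiliar with what "Python versions" could mean here, the following is a minimal sketch of the loop semantics written as plain Python, with no tracing or compilation involved. It is only an illustration of the idea, not the implementation that JAX uses or adopted: the function names mirror the `lax` ones, but the bodies are assumptions, `fori_loop` is expressed through `while_loop` to reflect the dependency described above, and `scan` collects its outputs in a plain list rather than stacking them into an array.

```python
# Illustrative pure-Python loop semantics (a sketch, not JAX's actual impl rules).

def while_loop(cond_fun, body_fun, init_val):
    """Apply body_fun repeatedly while cond_fun holds, using a Python loop."""
    val = init_val
    while cond_fun(val):
        val = body_fun(val)
    return val


def fori_loop(lower, upper, body_fun, init_val):
    """Counted loop built on while_loop by carrying the index in the state."""
    def cond(state):
        i, _ = state
        return i < upper

    def body(state):
        i, val = state
        return i + 1, body_fun(i, val)

    _, result = while_loop(cond, body, (lower, init_val))
    return result


def scan(f, init, xs):
    """Thread a carry through xs, collecting the per-step outputs in a list."""
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


if __name__ == "__main__":
    print(while_loop(lambda v: v < 100, lambda v: v * 2, 1))
    print(fori_loop(0, 5, lambda i, acc: acc + i, 0))
    print(scan(lambda c, x: (c + x, c + x), 0, [1, 2, 3]))
```

Because each step in a version like this dispatches the loop body as ordinary Python, nothing needs to be compiled for the loop construct itself when it is evaluated outside `@jit`.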
Issue Analytics
- State:
- Created 4 years ago
- Reactions:6
- Comments:8 (8 by maintainers)
We haven’t forgotten about this! We’re working on a rewrite to remove tuples from the jaxpr language (it’s an internals-only change, it won’t change the API anywhere) and as part of that change we’re revising the control flow pretty heavily (because it uses a lot of tuples). Once that other change lands we should fix these performance bugs once and for all!
Hi @mattjj, previously we observed that `jit(lax.fori_loop, ...)` is faster than `lax.fori_loop`, depending on `body_fn`. I think that the following code (a little verbose, to make explicit what I want to illustrate) is good for benchmarking because it illustrates some problems with the latest PR:
The target is to use various versions of `optimize` on `f`. Here are the results with the latest PR:
While before the PR, I get
I can observe that the behaviour of `lax.fori_loop` outside `jit` has been changed by the last PR, and seems worse than before. (By the way, while playing with some benchmark code, I observed that in recent versions of jax (e.g. v0.1.39), `lax.fori_loop` seems a bit faster than `jit(lax.fori_loop, ...)`, and I am unable to construct an example showing that `jit(lax.fori_loop, ...)` is faster than `lax.fori_loop`.)