lax.cond does not exist, contrary to the README
The README says:
If you want compiled control flow, use structured control flow primitives like lax.cond and lax.while.
But lax.cond does not actually exist, and lax.while only exists in the form of lax._while_loop.
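The README's advice exists because jax.jit traces a function with abstract values, so a plain Python if on an array's value cannot be resolved at trace time. Below is a minimal illustration. It assumes a recent JAX release, and the exception class comes from the current jax.errors module, which postdates this issue. The name relu_python_if is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
def relu_python_if(x):
    # Python `if` needs a concrete bool, but under jit `x > 0` is an abstract tracer.
    if x > 0:
        return x
    return jnp.zeros_like(x)


try:
    relu_python_if(jnp.float32(2.0))
    failed = False
except jax.errors.ConcretizationTypeError:
    # Raised at trace time because the branch condition has no concrete value.
    failed = True

print(failed)  # True
```

Without jax.jit the same function runs fine, because x is then a concrete array and the comparison can be converted to a Python bool.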
- Is lax.cond already implemented somewhere, just not on the master branch? I don’t need gradient support for lax.cond (which is tracked by PR #83); a jit-compilable cond alone would be a huge gain for me (I’m currently looking into whether I could contribute this if it doesn’t exist yet).
- lax.while will probably never exist, because while is a reserved keyword, so it should probably be renamed to lax.while_loop. It could probably be made available already by renaming lax._while_loop. Or is there a problem with the _while_loop implementation? (fori_loop and foreach_loop are not prefixed with an underscore and use _while_loop, so it should be fully functional.)
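Both primitives are public in current JAX releases, as lax.cond and lax.while_loop (the rename suggested above), alongside lax.fori_loop. The sketch below assumes a recent JAX version. The signatures are those of the current API, not of the code discussed in this issue, and relu, sum_below and sum_below_fori are illustrative names.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def relu(x):
    # lax.cond(pred, true_fun, false_fun, *operands) picks a branch at runtime;
    # both branches must return values of the same shape and dtype.
    return lax.cond(x > 0, lambda v: v, lambda v: jnp.zeros_like(v), x)


@jax.jit
def sum_below(n):
    # lax.while_loop(cond_fun, body_fun, init_val); the carry here is (i, total).
    def cond_fun(carry):
        i, _ = carry
        return i < n

    def body_fun(carry):
        i, total = carry
        return i + 1, total + i

    _, total = lax.while_loop(cond_fun, body_fun, (0, 0))
    return total


@jax.jit
def sum_below_fori(n):
    # lax.fori_loop(lower, upper, body_fun, init_val) with body_fun(i, val).
    return lax.fori_loop(0, n, lambda i, total: total + i, 0)


print(relu(jnp.float32(-3.0)), relu(jnp.float32(2.0)))
print(sum_below(5), sum_below_fori(5))
```

Because n is traced here, fori_loop is lowered to a while_loop. When both bounds are concrete Python integers it uses lax.scan instead, which is what allows reverse-mode differentiation through the loop.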

Early versions of lax.cond (circa early 2017) lowered into while loops (though I think we needed two in general; I can’t remember why). But they were totally broken from a tracing-composability perspective, which is why we don’t have them now.

The main reason we haven’t exposed a lax.cond yet is essentially the same reason #331 and #207 are still open: we want to handle closures and arbitrary composability correctly. @dougalm designed the core system to handle these issues, and it actually enables two ways of handling higher-order functions like lax.cond and lax.while, which we can call “the hard way” and “the easy way”. We recently decided that the practical benefits of the hard way over the easy way are pretty minuscule (though academically interesting), so @dougalm started going “the easy way” in #334.

#334 doesn’t add lax.cond, but it paves the way for doing it (along with a differentiable lax.map and lax.scan).

@jekbradbury np.where(b, x, y) is also possible (aka lax.select with different broadcasting semantics).
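The two alternatives in the last comment differ in how strictly they treat shapes. The sketch below assumes current JAX, where the NumPy API is jax.numpy (imported here as jnp rather than np). Unlike lax.cond, both jnp.where and lax.select evaluate both candidate arrays and then choose elementwise.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([-2.0, 0.5, 3.0])

# jnp.where broadcasts its arguments NumPy-style, so a scalar fallback is fine.
clipped = jnp.where(x > 0, x, 0.0)

# lax.select does not broadcast: on_true and on_false must have identical shapes
# and dtypes, so the zero fallback is built explicitly with zeros_like.
clipped_select = lax.select(x > 0, x, jnp.zeros_like(x))

print(clipped)
print(clipped_select)
```

Because both sides are computed, these elementwise selects suit cheap branches. lax.cond avoids the unused computation when the predicate is a scalar.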