Does lax cond short circuit?
Hello! I have a function f that wraps two functions, one of which is very expensive (f_1), the other (f_2) is not (they return the same shaped array). If one of the arguments to f is false, we do not need the expensive function. Ultimately, I wrap this inside a jitted function, so I must use lax.cond to split f into f_1 and f_2. Does this buy me anything, or do both sides of the conditional have to be executed because of the way JAX works? Thanks!
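For concreteness, the setup described in the question might look like the sketch below. The bodies of f_1 and f_2 are hypothetical stand-ins (the original functions are not shown), and the argument name use_expensive is an assumption; the call uses the lax.cond(pred, true_fun, false_fun, *operands) form.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f_1(x):
    # Hypothetical stand-in for the expensive function: 1000 sequential steps.
    return lax.fori_loop(0, 1000, lambda i, v: jnp.sin(v) + x, x)


def f_2(x):
    # Hypothetical stand-in for the cheap function; same output shape as f_1.
    return 2.0 * x


@jax.jit
def f(x, use_expensive):
    # x is passed as an explicit operand, so each branch depends on it
    # through its own argument rather than through a closure.
    return lax.cond(use_expensive, f_1, f_2, x)


x = jnp.ones(4)
cheap = f(x, False)   # selects f_2
costly = f(x, True)   # selects f_1
```

Both branches have to return arrays of the same shape and dtype, which matches the constraint stated in the question.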
Issue Analytics
- State:
- Created 3 years ago
- Comments:16 (9 by maintainers)
I think the jax.lax.cond API has changed since this issue was first opened, and I'm not sure @mattjj's comments apply in the same way. For example, if I do [...] then doing f(2) will run the infinite loop. How can I avoid that?

One detail to add on: only the operations in each branch that have a data dependence on the explicit branch operands will be delayed; operations with no data dependence on the operands are executed at trace time when not using a jit, and unconditionally when using a jit.

Here's an example: [...]

On the current master branch, both np.sin(x) and np.cos(x) will be evaluated on each evaluation of f(x). Another way to put it is that they'll be hoisted out of the cond entirely.

To ensure only one side is executed per application of f, we'd need to rewrite it as [...]

This is a weird quirk of our tracing implementation, and we're working on revising it. Hoping to land a fix in the next couple weeks!
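The difference between closing over x and passing it as an explicit operand can be sketched as follows. This is an illustration under assumptions, not the code from the original comment: it uses the lax.cond(pred, true_fun, false_fun, *operands) calling convention and jnp in place of np, and the names f_closure and f_operand are hypothetical. It only shows the two ways of writing the branches; it does not establish which of them a given JAX version hoists.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def f_closure(x):
    # Branches capture x from the enclosing scope; the explicit operand
    # (None) carries no data, so the branch bodies have no data dependence
    # on it. This is the form the comment describes as being hoisted.
    return lax.cond(x > 0, lambda _: jnp.sin(x), lambda _: jnp.cos(x), None)


@jax.jit
def f_operand(x):
    # x is the explicit operand, so sin and cos depend on the branch argument.
    return lax.cond(x > 0, jnp.sin, jnp.cos, x)


positive = f_operand(1.0)    # takes the sin branch
negative = f_operand(-1.0)   # takes the cos branch
```

The two functions return the same values; the distinction the comment draws is only about when the work inside each branch gets executed.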