Does lax.cond short-circuit?

Hello! I have a function f that wraps two functions: an expensive one (f_1) and a cheap one (f_2), which return arrays of the same shape. If one of the arguments to f is false, the expensive function is not needed. Because I ultimately wrap this inside a jitted function, I must use lax.cond to choose between f_1 and f_2. Does this buy me anything, or do both branches of the conditional have to be executed because of the way JAX works? Thanks!
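A minimal sketch of the setup described in the question, assuming the current `jax.lax.cond(pred, true_fun, false_fun, *operands)` signature. The bodies of `f_1` and `f_2` are hypothetical stand-ins. What matters is that the predicate is a traced argument and that `x` reaches the branches as an explicit operand rather than through a closure.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f_1(x):
    # Hypothetical stand-in for the question's expensive function.
    return jnp.sum(jnp.outer(x, x), axis=1)


def f_2(x):
    # Hypothetical stand-in for the cheap function; same shape and dtype as f_1.
    return x


@jax.jit
def f(use_expensive, x):
    # The predicate is a traced boolean scalar; x is passed to the branch
    # functions as the explicit operand instead of being closed over.
    return lax.cond(use_expensive, f_1, f_2, x)
```

Under `jit`, `lax.cond` is lowered to an XLA conditional, so only the branch selected at run time should execute. There is one caveat. If `f` is `vmap`ped with a batched predicate, the cond is converted into a select, and both branches are evaluated.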

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 16 (9 by maintainers)

Top GitHub Comments

8 reactions
pedrofale commented, Oct 26, 2020

I think the jax.lax.cond API has changed since this issue was first opened, and I’m not sure @mattjj’s comments apply in the same way. For example, if I write

import jax

def f(x):
  # The false branch is an infinite loop that never reads its operand x.
  return jax.lax.cond(x > 0, lambda x: x**2,
                      lambda x: jax.lax.while_loop(lambda x: True, lambda _: _, 0), x)

then calling f(2) runs the infinite loop, even though 2 > 0 selects the cheap x**2 branch. How can I avoid that?
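One workaround follows the data-dependence rule described in this thread: pass the branch operand into the loop, so that the loop's carry depends on `x` rather than on a constant. The sketch below uses a terminating loop that counts up to zero in place of the infinite one, so that it can actually run. The branch names `cheap` and `looping` are illustrative. Whether the original operand-free version hangs may depend on the JAX version in use.

```python
import jax
from jax import lax


@jax.jit
def f(x):
    def cheap(x):
        return x ** 2

    def looping(x):
        # The loop carry starts from the operand x, so the loop has a data
        # dependence on the branch operand instead of starting from a constant.
        # Terminating stand-in for the infinite loop: count up until c >= 0.
        return lax.while_loop(lambda c: c < 0.0, lambda c: c + 1.0, x)

    return lax.cond(x > 0, cheap, looping, x)
```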

4 reactions
mattjj commented, May 16, 2020

One detail to add: only the operations in each branch that have a data dependence on the explicit branch operands are delayed, that is, executed only when their branch is taken. Operations with no data dependence on the operands are executed at trace time when not using jit, and unconditionally on every call when using jit.

Here’s an example:

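from jax import jit, lax
import jax.numpy as np  # jax.numpy, under the alias common in 2020-era code

@jit
def f(x):
  # 2020-era signature: cond(pred, true_operand, true_fun, false_operand, false_fun)
  return lax.cond(x > 0,
                  (), lambda _: np.sin(x),
                  (), lambda _: np.cos(x))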

On the current master branch (as of May 2020), both np.sin(x) and np.cos(x) are evaluated on every call to f(x). Put another way, they are hoisted out of the cond entirely.
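Whether operations that don't depend on the branch operands land inside the `cond` or are hoisted before it has varied across JAX versions. One way to check for your installed version is to print the traced program with `jax.make_jaxpr`. This sketch uses the current `cond(pred, true_fun, false_fun, *operands)` argument order. The names `f_closure` and `f_operand` are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax


def f_closure(x):
    # Branches ignore their (empty) operand and close over x instead.
    return lax.cond(x > 0, lambda _: jnp.sin(x), lambda _: jnp.cos(x), None)


def f_operand(x):
    # Branches receive x explicitly as the operand.
    return lax.cond(x > 0, jnp.sin, jnp.cos, x)


# Print each traced program and look at whether `sin`/`cos` appear inside
# the `cond[branches=...]` equation or before it.
for fn in (f_closure, f_operand):
    print(jax.make_jaxpr(fn)(1.0))
```

In the printed jaxpr, equations nested inside `cond[branches=(...)]` belong to a branch. Equations listed before the `cond` line run unconditionally.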

To ensure that only one branch is executed per call to f, we’d need to pass x in as the explicit operand of each branch:

@jit
def f(x):
  return lax.cond(x > 0,
                  x, lambda x: np.sin(x),
                  x, lambda x: np.cos(x))
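The five-argument form above, with a separate operand and function for each branch, comes from the 2020-era API. Here is a sketch of the same rewrite using the current `lax.cond(pred, true_fun, false_fun, *operands)` argument order. A single operand `x` is passed to whichever branch runs.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def f(x):
    # Current argument order: cond(pred, true_fun, false_fun, *operands).
    # x is the operand, so each branch's work depends on it.
    return lax.cond(x > 0, jnp.sin, jnp.cos, x)
```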

This is a weird quirk of our tracing implementation, and we’re working on revising it. We’re hoping to land a fix in the next couple of weeks!


Top Results From Across the Web

jax.lax.cond - JAX documentation - Read the Docs
operands – Operands (A) input to either branch depending on pred. The type can be a scalar, array, or any pytree (nested...
Why would a language NOT use Short-circuit evaluation?
Reasons NOT to use short-circuit evaluation: Because it will behave differently and produce different results if your functions, ...
