logical operators with jit?

This feels like it might be a bug in how boolean values are traced, and if not, it’s perhaps a rather subtle sharp bit?

I’ve run into some trouble trying to jit a function containing jax.lax.cond, and I’ve reduced it to a toy example that I can’t quite wrap my head around (below). It seems like logical operators might be the problem, because things work fine with a predicate like x<0, but not with x<0 and y<0, or with not x<0. Based on my admittedly incomplete understanding, it ought to be possible to determine that both kinds of expression result in the same type (ShapedArray((), np.bool_)?), so if jit+jax.lax.cond works for the former, it seems like it ought to work for the latter as well. Thanks in advance for any insight into this!
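One way to probe the hypothesis about types is to trace the predicates abstractly. The sketch below assumes a recent JAX release and uses jax.eval_shape, which runs a function on abstract values only; it suggests that a single comparison and a compound predicate built with the bitwise & operator do produce the same abstract type, a boolean scalar. The lambdas and the names single and combined are illustrative only.

```python
import jax
import jax.numpy as jnp

x = jnp.float32(0.3)
y = jnp.float32(0.1)

# jax.eval_shape traces a function with abstract values only and
# returns a jax.ShapeDtypeStruct describing the result.
single = jax.eval_shape(lambda a, b: a < 0, x, y)
combined = jax.eval_shape(lambda a, b: (a < 0) & (b < 0), x, y)

print(single)    # expected: shape () and dtype bool
print(combined)  # expected: the same abstract type as `single`
```

The type hypothesis holds for the & form; the `and` form never gets that far, because Python has to turn the first comparison into a concrete True or False before any array is produced.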

Toy example:

import jax
from jax import jit

@jit
def true_fn(x):
    return x - 5000.0

@jit
def false_fn(x):
    return x + 100.0

@jit
def cond_test_1(x, y):
    return jax.lax.cond(x < 0, true_fn, false_fn, x)

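@jit
def cond_test_2(x, y):
    # Compound predicate built with Python's `and`; this is the version that fails under jit
    return jax.lax.cond(x < 0 and y < 0, true_fn, false_fn, x)

# Same as cond_test_2 but without jit, so x and y stay concrete Python floats
def cond_test_2_no_jit(x, y):
    return jax.lax.cond(x < 0 and y < 0, true_fn, false_fn, x)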

Results from running these functions:


>>> cond_test_1(0.3, 0.1)
DeviceArray(100.3, dtype=float32)

>>> cond_test_2_no_jit(0.3, 0.1)
DeviceArray(100.3, dtype=float32)

>>> cond_test_2(0.3, 0.1)
Traceback (most recent call last):

[[...snip...]]

jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in `bool`).
Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values.
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error`.
Encountered value: Traced<ShapedArray(bool[]):JaxprTrace(level=-1/1)>
>>> 
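The traceback mentions `bool`, which hints that the error comes from Python's truth-value conversion rather than from jax.lax.cond itself. The minimal sketch below, assuming a recent JAX where the error class is exposed as jax.errors.ConcretizationTypeError, reproduces it without lax.cond at all; and_pred, not_pred and raises_concretization_error are hypothetical helper names.

```python
import jax

@jax.jit
def and_pred(x, y):
    # Python's `and` needs a concrete truth value for `x < 0`,
    # so it calls bool() on a tracer while jit is tracing.
    return x < 0 and y < 0

@jax.jit
def not_pred(x):
    # `not` also calls bool() on its operand.
    return not x < 0

def raises_concretization_error(fn, *args):
    """Return True if calling fn(*args) raises ConcretizationTypeError."""
    try:
        fn(*args)
    except jax.errors.ConcretizationTypeError:
        return True
    return False

print(raises_concretization_error(and_pred, 0.3, 0.1))
print(raises_concretization_error(not_pred, 0.3))
```

In recent JAX versions the exception actually raised is TracerBoolConversionError, a subclass of ConcretizationTypeError, so catching the parent class covers both older and newer releases that provide jax.errors.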

PS: I’m really loving JAX; it’s a revelation! Thanks to the team! ❤️

Issue Analytics

  • State: open
  • Created 3 years ago
  • Comments: 5 (1 by maintainers)

Top GitHub Comments

2 reactions
jakevdp commented, Jul 14, 2020

The issue here is with Python’s and, which calls bool on its arguments. You should be able to use bitwise operators in its place; for example:

@jit
def cond_test_2(x, y):
    return jax.lax.cond((x < 0) & (y < 0), true_fn, false_fn, x)

(don’t forget the parentheses, because & has higher precedence than <)
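The original question also asked about `or` and `not`. A sketch of the same idea for all three operators follows, assuming the current jax.lax.cond(pred, true_fun, false_fun, *operands) signature; cond_and, cond_or and cond_not are illustrative names, and true_fn/false_fn are re-declared without @jit so the block runs on its own. | and ~ work on boolean arrays; jnp.logical_or (and likewise jnp.logical_and / jnp.logical_not) are the named equivalents.

```python
import jax
import jax.numpy as jnp
from jax import jit

def true_fn(x):
    return x - 5000.0

def false_fn(x):
    return x + 100.0

@jit
def cond_and(x, y):
    # `&` instead of `and`; parentheses are needed because `&` binds tighter than `<`
    return jax.lax.cond((x < 0) & (y < 0), true_fn, false_fn, x)

@jit
def cond_or(x, y):
    # jnp.logical_or is the named equivalent of `(x < 0) | (y < 0)`
    return jax.lax.cond(jnp.logical_or(x < 0, y < 0), true_fn, false_fn, x)

@jit
def cond_not(x):
    # `~` on a boolean array is logical negation, replacing `not`
    return jax.lax.cond(~(x < 0), true_fn, false_fn, x)

print(cond_and(0.3, 0.1))   # false branch: 0.3 + 100.0
print(cond_or(-1.0, 0.1))   # true branch: -1.0 - 5000.0
print(cond_not(0.3))        # true branch: 0.3 - 5000.0
```

Note that unlike `and`/`or`, these operators do not short-circuit: both sides are always evaluated, which is what lets them run on traced values.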

Note that the same thing is in play when x and y are standard numpy arrays: (x < 0) and (y < 0) will error for arrays with more than one element, while (x < 0) & (y < 0) will return the elementwise boolean output.
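The NumPy behavior mentioned above can be checked directly. The arrays below are arbitrary example data chosen for illustration.

```python
import numpy as np

x = np.array([-1.0, 2.0, -3.0])
y = np.array([-0.5, -0.5, 4.0])

# `&` combines the two boolean arrays element by element
mask = (x < 0) & (y < 0)
print(mask)  # [ True False False]

# `and` has to call bool() on the left-hand array, which is ambiguous
# for more than one element, so NumPy raises ValueError
raised = False
try:
    (x < 0) and (y < 0)
except ValueError as err:
    raised = True
    print("ValueError:", err)
```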

0 reactions
benschreiber commented, Jul 19, 2020

I’m a newbie and I just ran into the exact same issue. It would be great to see a mention of this in the docs.
