logical operators with jit?
This feels like it might be a bug in how boolean values are traced, and if not, it's perhaps a rather subtle sharp bit?

I've run into some trouble trying to `jit` a function containing `jax.lax.cond`, and I've reduced it to a toy example (below) that I can't quite wrap my head around. It seems like logical operators might be the problem, because things work fine with a predicate like `x < 0`, but not with `x < 0 and y < 0` or `not x < 0`. Based on my admittedly incomplete understanding, it ought to be possible to determine that both expressions result in the same type (`ShapedArray((), np.bool_)`?), so if `jit` + `jax.lax.cond` works for the former, it seems like it ought to work for the latter as well. Thanks in advance for any insight into this!
Toy example:

```python
import jax
from jax import jit

@jit
def true_fn(x):
    return x - 5000.0

@jit
def false_fn(x):
    return x + 100.0

@jit
def cond_test_1(x, y):
    return jax.lax.cond(x < 0, true_fn, false_fn, x)

@jit
def cond_test_2(x, y):
    return jax.lax.cond(x < 0 and y < 0, true_fn, false_fn, x)

def cond_test_2_no_jit(x, y):
    return jax.lax.cond(x < 0 and y < 0, true_fn, false_fn, x)
```
Results from running these functions:

```
>>> cond_test_1(0.3, 0.1)
DeviceArray(100.3, dtype=float32)
>>> cond_test_2_no_jit(0.3, 0.1)
DeviceArray(100.3, dtype=float32)
>>> cond_test_2(0.3, 0.1)
Traceback (most recent call last):
[[...snip...]]
jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected (in `bool`).
Use transformation parameters such as `static_argnums` for `jit` to avoid tracing input values.
See `https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error`.
Encountered value: Traced<ShapedArray(bool[]):JaxprTrace(level=-1/1)>
```
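The traceback points at `bool`. A minimal sketch of the mechanism, assuming a recent JAX release where `jax.errors.ConcretizationTypeError` is importable: under `jit`, `x < 0` is a traced boolean array with no concrete value yet, so returning or passing it along is fine, but anything that forces it into a Python `bool` (`and`, `not`, `if`) fails during tracing. The function names below are hypothetical, chosen for illustration only.

```python
import jax


@jax.jit
def comparison_only(x):
    # Returns a traced bool[] array; no Python bool conversion happens.
    return x < 0


@jax.jit
def with_not(x):
    # `not` calls bool() on the traced comparison, which cannot be done while tracing.
    return not x < 0


@jax.jit
def with_and(x, y):
    # `and` calls bool() on its first operand to decide whether to short-circuit.
    return x < 0 and y < 0


def fails_under_jit(fn, *args):
    """Return True if calling the jitted fn raises a concretization error."""
    try:
        fn(*args)
    except jax.errors.ConcretizationTypeError:
        return True
    return False


print(comparison_only(-1.0))  # a scalar boolean array
print(fails_under_jit(with_not, 0.3))
print(fails_under_jit(with_and, 0.3, 0.1))
```

Without `jit`, the same `and`/`not` expressions work because the comparisons produce concrete values, which is why `cond_test_2_no_jit` succeeds.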
PS: I'm really loving JAX, it's a revelation! Thanks to the team! ❤️
(Issue created 3 years ago; 5 comments, 1 by maintainers.)
The issue here is with Python's `and`, which calls `bool` on its arguments. You should be able to use bitwise operators in its place, e.g. `(x < 0) & (y < 0)` (don't forget the parentheses, because `&` has higher precedence than `<`).

Note that the same thing is in play when `x` and `y` are standard NumPy arrays: `(x < 0) and (y < 0)` will error for arrays with more than one element, while `(x < 0) & (y < 0)` will return the elementwise boolean output.

I'm a newbie and I just ran into the exact same issue. It would be great to see a mention of this in the docs.
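Applying the maintainer's suggestion to the toy example: a sketch of `cond_test_2` rewritten with `&`, plus the `not x < 0` case from the question rewritten with `~`; `jnp.logical_and` is shown as a function-style equivalent. The `_fixed`, `_not`, and `_logical` names are hypothetical, and `true_fn`/`false_fn` are copied from the issue without their inner `@jit`, which is not needed inside an already-jitted function. The last few lines show the NumPy behavior described above.

```python
import jax
import jax.numpy as jnp
import numpy as np


def true_fn(x):
    return x - 5000.0


def false_fn(x):
    return x + 100.0


@jax.jit
def cond_test_2_fixed(x, y):
    # `&` combines the traced booleans instead of calling bool();
    # parentheses are required because `&` binds tighter than `<`.
    return jax.lax.cond((x < 0) & (y < 0), true_fn, false_fn, x)


@jax.jit
def cond_test_not(x, y):
    # `~` on a boolean array is logical negation, replacing Python's `not`.
    return jax.lax.cond(~(x < 0), true_fn, false_fn, x)


@jax.jit
def cond_test_2_logical(x, y):
    # Function-style equivalent of the `&` version.
    return jax.lax.cond(jnp.logical_and(x < 0, y < 0), true_fn, false_fn, x)


print(cond_test_2_fixed(0.3, 0.1))    # predicate False -> false_fn
print(cond_test_2_fixed(-0.3, -0.1))  # predicate True -> true_fn
print(cond_test_not(0.3, 0.1))        # ~(0.3 < 0) is True -> true_fn

# The same distinction with plain NumPy arrays:
a = np.array([-1.0, 2.0])
b = np.array([-3.0, -4.0])
print((a < 0) & (b < 0))  # elementwise result: [ True False]
try:
    (a < 0) and (b < 0)
except ValueError as err:
    # The truth value of a multi-element array is ambiguous.
    print("ValueError:", err)
```

The inputs here are scalars, so the combined predicate is a scalar boolean, which is what `jax.lax.cond` expects.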