A case in which `jax.jit` changes results, prevents backprop, and inhibits optimisations
The summary is that I’ve found an (edge) case in which:
- adding a `jax.jit` decorator can change the functional behaviour of the program;
- adding a `jax.jit` decorator can prohibit reverse-mode autodifferentiation;
- adding a `jax.jit` decorator can inhibit the use of compile-time optimisations.
In all cases it is due to the same single root cause: a nested `jax.jit` unnecessarily converts concrete values to tracers.
First a MWE of the root cause:

    import jax

    @jax.jit
    def f(x):
        print(x)

    @jax.jit
    def g():
        f(1)

    g()

This program prints `Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/2)>`: the `1` is unnecessarily converted into a tracer.
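To make the contrast concrete, here is a minimal sketch that records, rather than prints, whether the literal `1` arrives as a tracer. The helper `record` and the dictionary `seen` are hypothetical names introduced for illustration; the check relies on `jax.core.Tracer`, the base class of JAX tracers.

```python
import jax

# Hypothetical bookkeeping: maps a label to "did this call see a tracer?"
seen = {}

def record(name):
    """Build a function that notes whether its argument is a JAX tracer."""
    def inner(x):
        seen[name] = isinstance(x, jax.core.Tracer)
        return x
    return inner

plain = record("plain")             # not decorated
jitted = jax.jit(record("jitted"))  # decorated with a nested jax.jit

@jax.jit
def g():
    plain(1)   # receives the Python int 1 unchanged
    jitted(1)  # the nested jit traces its argument

g()
print(seen)
```

Under these assumptions only the call that goes through the nested `jax.jit` should see a tracer; the undecorated call receives the plain Python integer.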
At first glance that probably seems reasonable. Why is this undesirable? Consider the following case.

    import jax
    import jax.lax as lax

    @jax.jit
    def f(init, lower):
        def _body_fun(_, _y):
            return _y + 1
        return lax.fori_loop(lower, 1, _body_fun, init)

    @jax.jit
    def g():
        return jax.grad(f)(0.0, 0)

    g()
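Adding a `jax.jit` can change the functional behaviour of the program
If `f` is decorated with `jax.jit` then the above program crashes. If `f` is undecorated then the program runs. The difference is that, under the nested `jax.jit`, `lower` arrives as a tracer, so `lax.fori_loop` no longer sees a static trip count and falls back to `lax.while_loop`, which does not support reverse-mode autodifferentiation.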
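The two variants can be compared side by side. The sketch below is a rearrangement of the program above under a few assumptions: `f_jit`, `g_plain`, `g_nested`, `grad_plain`, and `nested_error` are hypothetical names, and the failure is caught generically because the exact exception type and message may differ between JAX versions.

```python
import jax
import jax.lax as lax

def f(init, lower):
    def _body_fun(_, _y):
        return _y + 1
    # The upper bound is the literal 1; only `lower` comes from the caller.
    return lax.fori_loop(lower, 1, _body_fun, init)

f_jit = jax.jit(f)  # same function, wrapped in a nested jax.jit

@jax.jit
def g_plain():
    # `lower` stays a Python int, so the trip count is static.
    return jax.grad(f)(0.0, 0)

@jax.jit
def g_nested():
    # `lower` is converted to a tracer by the nested jit.
    return jax.grad(f_jit)(0.0, 0)

grad_plain = g_plain()

nested_error = None
try:
    g_nested()
except Exception as err:  # exact exception type is version-dependent
    nested_error = err

print("undecorated f:", grad_plain)
print("decorated f failed:", nested_error is not None)
```

Since `f(init, 0)` runs the loop body once, it computes `init + 1`, so the gradient in the undecorated case is `1.0`.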
Adding a `jax.jit` can inhibit reverse-mode autodifferentiation
(Same as above: if `f` is decorated with `jax.jit` then the above program crashes. If `f` is undecorated then the program runs.)
Adding a `jax.jit` can inhibit the use of compile-time optimisations
As a library author interested in optimising compile and run times, I’d like to specialise behaviour based on compile-time values.
In particular I have a case in which if a value is known at compile time then I can produce code that is efficient to compile. If it is only known at run time then the extra generality requires a more complicated program – that produces the same functional results – but is several orders of magnitude slower to compile.
(So morally speaking something a little similar to `lax.fori_loop` in native JAX.)
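As an illustration of this kind of specialisation (not the actual library code, which is not shown in this issue), consider a hypothetical function `power` that unrolls a loop in Python when its exponent is concrete and falls back to a traced loop otherwise. The list `paths` is also hypothetical and only records which branch was taken at trace time.

```python
import jax
import jax.lax as lax
import jax.numpy as jnp

paths = []  # hypothetical log of the branch chosen at trace time

def power(x, n):
    """Hypothetical library function: x ** n for a non-negative integer n."""
    if isinstance(n, jax.core.Tracer):
        # n is only known at run time: use the general traced loop.
        paths.append("dynamic")
        return lax.fori_loop(0, n, lambda _, acc: acc * x, jnp.ones_like(x))
    # n is known at trace time: unroll the loop in Python.
    paths.append("static")
    out = jnp.ones_like(x)
    for _ in range(n):
        out = out * x
    return out

@jax.jit
def outer(x):
    a = power(x, 3)           # 3 stays a Python int
    b = jax.jit(power)(x, 3)  # 3 is converted to a tracer by the nested jit
    return a, b

a, b = outer(jnp.float32(2.0))
print(paths, a, b)
```

Both calls return the same value, but only the undecorated call is able to take the specialised branch; wrapping `power` in `jax.jit` hides the fact that `n` was a literal.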
Resolution: I think the resolution should be that within the dynamic context of a `jax.jit` decorator, all other `jax.jit` decorators should be disabled.
Cf. #7155: sub-jit-functions are retraced within the dynamic context of a `jax.jit` decorator.
Overall, sub-jit-functions currently enter a weird state of being only “partially disabled” within the dynamic context of a `jax.jit` decorator: they are retraced, but values are still unnecessarily promoted to tracers.
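For this particular example, one workaround with the existing API is to mark the affected argument as static via `static_argnums`, so that the nested `jax.jit` passes it through as a Python value. This is not the resolution proposed above, only a sketch of how the failing program can be made to run; the name `result` is introduced for illustration.

```python
import functools

import jax
import jax.lax as lax

# Argument 1 (`lower`) is static: the nested jit does not trace it.
@functools.partial(jax.jit, static_argnums=1)
def f(init, lower):
    def _body_fun(_, _y):
        return _y + 1
    return lax.fori_loop(lower, 1, _body_fun, init)

@jax.jit
def g():
    return jax.grad(f)(0.0, 0)

result = g()
print(result)
```

The limitation is that a static argument must be a hashable concrete value at the call site and each distinct value triggers a separate compilation, so this does not help when the caller only sometimes has a concrete value to pass.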
If this is intended then it’s an “as intended” that only has downsides: having a sub-jit is only ever a Bad Thing™. We still have to re-trace, but now also have these cast-to-tracers. It would make more sense to either (a) fully disable the sub-jits (and thus improve compile-time evaluation), or (b) fully enable the sub-jits (and then avoid re-tracing, re-compiling etc.)
@patrick-kidger Ahh, my bad, sorry for being slow on your points. Just realized that sometimes `lax.fori_loop` is really just a `lax.while_loop`, hence allowing traced `lower` and `upper`. Somehow I only think of `fori_loop` in terms of `scan`. Then I agree typing `fori_loop` is not as easy as setting static arguments and