
A case in which `jax.jit` changes results, prevents backprop, and inhibits optimisations


The summary is that I’ve found an (edge) case in which:

  • adding a jax.jit decorator can change the functional behaviour of the program;
  • adding a jax.jit decorator can prevent reverse-mode autodifferentiation;
  • adding a jax.jit decorator can inhibit the use of compile-time optimisations.

In all three cases the root cause is the same: a nested jax.jit unnecessarily converts concrete values into tracers.


First, a minimal working example (MWE) of the root cause:

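import jax

@jax.jit
def f(x):
    # Runs at trace time, so this shows what f receives as its argument.
    print(x)

@jax.jit
def g():
    # Called during g's trace with the Python literal 1.
    f(1)

g()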

This program prints Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/2)>: the 1 is unnecessarily converted into a tracer. Without the @jax.jit on f, the same program prints 1.
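A minimal sketch contrasting what an inner function sees in three cases: a plain jax.jit, a jax.jit with the argument marked static via static_argnums, and no jit at all. The names f_traced, f_static, f_plain and the seen list are illustrative only; each function records, at trace time, whether its argument is still a Python int. Marking an argument static is a per-function workaround (the value must be hashable, and each distinct value triggers a retrace); it is not the fix this issue asks for.

```python
import functools

import jax

seen = []  # (label, is_python_int), recorded while each function is traced

@jax.jit
def f_traced(x):
    # A plain jax.jit abstracts every argument to a tracer.
    seen.append(("traced", isinstance(x, int)))

@functools.partial(jax.jit, static_argnums=0)
def f_static(x):
    # A static argument is passed through as the original Python value.
    seen.append(("static", isinstance(x, int)))

def f_plain(x):
    # No jit: the function sees exactly what the caller passed.
    seen.append(("plain", isinstance(x, int)))

@jax.jit
def g():
    f_traced(1)
    f_static(1)
    f_plain(1)

g()
print(seen)  # [('traced', False), ('static', True), ('plain', True)]
```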

At first glance that probably seems reasonable. Why is this undesirable? Consider the following case.

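import jax
import jax.lax as lax

@jax.jit
def f(init, lower):
    def _body_fun(_, _y):
        return _y + 1
    # Loop from `lower` up to 1, starting from `init`.
    return lax.fori_loop(lower, 1, _body_fun, init)

@jax.jit
def g():
    # Differentiate f with respect to `init`; `lower` is the literal 0.
    return jax.grad(f)(0.0, 0)

g()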

adding a jax.jit can change the functional behaviour of the program

If f is decorated with jax.jit then the above program crashes. If f is undecorated then the program runs.

adding a jax.jit can inhibit reverse-mode autodifferentiation

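(Same as above: if f is decorated with jax.jit then the above program crashes; if f is undecorated then the program runs. The crash is a reverse-mode failure: with f jitted, lower reaches lax.fori_loop as a tracer, so the trip count is no longer known at trace time and fori_loop is lowered to lax.while_loop, which does not support reverse-mode differentiation. With f undecorated, lower stays the Python int 0 and fori_loop can use scan instead.)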
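Assuming every caller can treat lower as a compile-time constant, one workaround sketch is to declare it static with static_argnums, so that it reaches lax.fori_loop as a Python int and the loop is lowered to scan. This gives up the traced-bound case entirely, which is the multiple-dispatch limitation raised in the comments on this issue.

```python
import functools

import jax
import jax.lax as lax

# Same body as the issue's f, but `lower` (positional index 1) is declared static,
# so it reaches lax.fori_loop as a plain Python int rather than a tracer.
@functools.partial(jax.jit, static_argnums=1)
def f(init, lower):
    def _body_fun(_, _y):
        return _y + 1
    # Both bounds are concrete, so the trip count is known at trace time and
    # fori_loop can use scan, which supports reverse-mode differentiation.
    return lax.fori_loop(lower, 1, _body_fun, init)

@jax.jit
def g():
    # jax.grad differentiates with respect to `init` (argnums=0 by default)
    # and passes the literal 0 through to f unchanged.
    return jax.grad(f)(0.0, 0)

# One loop iteration: f(init) = init + 1, so the gradient is 1.0.
print(g())
```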

adding a jax.jit can inhibit the use of compile-time optimisations

As a library author interested in optimising compile times and run times, I’d like to specialise behaviour based on compile-time values.

In particular, I have a case in which, if a value is known at compile time, I can produce code that is efficient to compile. If it is only known at run time, the extra generality requires a more complicated program (one that produces the same functional results) that is several orders of magnitude slower to compile.

(Morally speaking, this is a little similar to lax.fori_loop in native JAX, which uses scan when its bounds are concrete and while_loop otherwise.)
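To make the library-author scenario concrete, here is a hypothetical library function add_n_times (not from the issue) that specialises on whether n is a Python int: a fixed-length lax.scan when it is, and a general lax.while_loop otherwise. The isinstance(n, int) check is an assumption standing in for however a real library detects compile-time values. Calling the same caller through a nested jax.jit silently moves it onto the general path, even though n is the literal 3 in both calls.

```python
import jax
import jax.lax as lax

branches = []  # records which code path the hypothetical library took

def add_n_times(x, n):
    """Hypothetical library function: adds 1 to x, n times."""
    if isinstance(n, int):
        # n is known at trace time: a fixed-length scan (reverse-mode differentiable).
        branches.append("static")

        def scan_body(carry, _):
            return carry + 1, None

        out, _ = lax.scan(scan_body, x, None, length=n)
        return out
    # n is only known at run time: fall back to a more general while_loop.
    branches.append("dynamic")

    def cond(state):
        i, _ = state
        return i < n

    def loop_body(state):
        i, y = state
        return i + 1, y + 1

    _, out = lax.while_loop(cond, loop_body, (0, x))
    return out

def caller(x, n):
    return add_n_times(x, n)

@jax.jit
def outer(x):
    a = caller(x, 3)           # n stays a Python int
    b = jax.jit(caller)(x, 3)  # the nested jit turns n into a tracer
    return a, b

a, b = outer(1.0)
print(branches)  # ['static', 'dynamic']
```

Both results are numerically the same; only the generated program differs.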


Resolution: I think that within the dynamic context of a jax.jit decorator, all other (nested) jax.jit decorators should be disabled.
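A user-level approximation of this proposal, assuming all the relevant functions use the same wrapper, is a hypothetical soft_jit decorator (not a JAX API): it applies jax.jit at top level, but calls the undecorated function directly when invoked while another soft_jit function is being traced. It detects this with a thread-local counter that is only incremented while the jitted body is executing, which happens only during tracing. It cannot see plain jax.jit contexts, and jax.disable_jit() is not a substitute, since it disables the outer jit too.

```python
import functools
import threading

import jax

_state = threading.local()  # per-thread count of soft_jit bodies being traced

def _inside_soft_jit_trace():
    return getattr(_state, "depth", 0) > 0

def soft_jit(fun):
    """Hypothetical decorator: jit at top level, inline inside another soft_jit trace."""

    @functools.wraps(fun)
    def traced_body(*args, **kwargs):
        # This Python body only runs while jax.jit is tracing it.
        _state.depth = getattr(_state, "depth", 0) + 1
        try:
            return fun(*args, **kwargs)
        finally:
            _state.depth -= 1

    jitted = jax.jit(traced_body)

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        if _inside_soft_jit_trace():
            # Inline: concrete Python values are passed through unchanged.
            return fun(*args, **kwargs)
        return jitted(*args, **kwargs)

    return wrapper

seen = []

@soft_jit
def f(x):
    seen.append(isinstance(x, int))

@soft_jit
def g():
    f(1)

g()   # f is inlined inside g's trace, so it sees the Python int 1
f(1)  # at top level f is jitted, so it sees a tracer
print(seen)  # [True, False]
```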


Cf. #7155: sub-jit-functions are retraced within the dynamic context of a jax.jit decorator.

Overall, sub-jit-functions currently enter a weird state of being only “partially disabled” within the dynamic context of a jax.jit decorator: they are retraced, but values are still unnecessarily promoted to tracers.

Issue Analytics

  • State: open
  • Created 2 years ago
  • Comments: 19 (12 by maintainers)

Top GitHub Comments

patrick-kidger commented, Jan 24, 2022 (3 reactions):

If this is intended, then it’s an “as intended” that only has downsides: having a sub-jit is only ever a Bad Thing™. We still have to re-trace, but now we also get these casts to tracers. It would make more sense either to (a) fully disable the sub-jits (and thus improve compile-time evaluation), or (b) fully enable them (and so avoid re-tracing, re-compiling, etc.).

sororos commented, Jan 27, 2022 (2 reactions):

@patrick-kidger Ahh, my bad, sorry for being slow on your points. Just realized that sometimes lax.fori_loop is really just a lax.while_loop, hence allowing traced lower and upper. Somehow I only think of fori_loop in terms of scan. Then I agree typing fori_loop is not as easy as setting static arguments and

As it stands there is no way to wrap lax.fori_loop into a jax.jit, because this kills off the possibility of multiple dispatch; each argument has to be either traced or static.
