jax.jit recompiles nested jitted functions
This isn't really a bug, more of a question, I guess, but who knows.
Why is f being compiled twice in this example:
```python
import numpy as onp  # original NumPy
import jax
import jax.numpy as np

x = onp.random.rand(1024).astype(onp.float32)
x = jax.device_put(x)

def f(x):
    print('f x', x)
    return np.square(x)

def g(f, x):
    print('g f', f)
    print('g x', x)
    return f(x)

def h(f, x):
    print('h f', f)
    print('h x', x)
    return f(x)

f = jax.jit(f)
g = jax.jit(g, static_argnums=(0,))
h = jax.jit(h, static_argnums=(0,))

f(x)  # triggers compilation of f
f(x)  # reuses cache
print('')
g(f, x)  # triggers compilation of g and another compilation of f
g(f, x)  # reuses cache
print('')
h(f, x)  # triggers compilation of h, but uses cached f
h(f, x)  # reuses cache
```
This prints:

```
f x Traced<ShapedArray(float32[1024]):JaxprTrace(level=-1/1)>

g f <function jit.<locals>.f_jitted at 0x7faa5bb22ae8>
g x Traced<ShapedArray(float32[1024]):JaxprTrace(level=-1/1)>
f x Traced<ShapedArray(float32[1024]):JaxprTrace(level=-1/2)>

h f <function jit.<locals>.f_jitted at 0x7faa5bb22ae8>
h x Traced<ShapedArray(float32[1024]):JaxprTrace(level=-1/1)>
```
I guess it has something to do with the level, but I don’t really get why a jitted function is recompiled once it gets called from within another jitted function.
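As a point of comparison, here is a small sketch (assuming `jax` and `numpy` are installed) showing that if the *un-jitted* `f` is passed in, the outer `jit` simply traces it inline and there is no separate compilation of `f` at all; a global counter stands in for watching the traces:

```python
import numpy as onp
import jax
import jax.numpy as np

trace_count = 0

def f(x):
    # Python side effects run only while f is being traced, so this
    # counts how many times JAX (re)traces f.
    global trace_count
    trace_count += 1
    return np.square(x)

def g(f, x):
    return f(x)

x = jax.device_put(onp.random.rand(1024).astype(onp.float32))

g_jitted = jax.jit(g, static_argnums=(0,))

# Passing the plain Python f: it is traced inline as part of g,
# so there is no standalone compilation (or recompilation) of f.
g_jitted(f, x)
g_jitted(f, x)  # cache hit: f is not traced again
print(trace_count)  # 1
```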
Issue Analytics

- Created 5 years ago
- Comments: 5 (5 by maintainers)
Following up on this, now that I worked a bit more with it: calling jitted functions from within jitted functions does not seem to be a good idea; i.e. not jitting the inner function explicitly (it will still be jitted when the whole outer function is jitted) seems to improve performance. Is this a general rule one can keep in mind (i.e. for maximum performance, put `jit` only on the outermost function), or are there cases where jitting functions inside other jitted functions is advantageous?

Our trace-caching logic is pretty simple: it's just a `@memoize` decorator on the function that takes a wrapped Python callable `fun` and a set of abstract arguments and returns an executable XLA computation. The wrapping of `fun` just records what transformations have been applied to the underlying Python callable (and any auxiliary information that they need to smuggle out after the function has been called, like the tuple/list/dict tree structure of the output), so that `@memoize` decorator is taking that transformation stack into account too. What's complicated is the transformations we need to set up to guard against other traces' tracers hiding in function closures.
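As a toy model of that caching scheme (this is illustrative only, not JAX's actual code — the real cache keys on the wrapped function *and* its transformation stack), a memoized compile keyed on the function object plus the abstract shapes/dtypes of its arguments looks roughly like:

```python
import functools

def abstractify(x):
    # A toy "abstract value": shape and dtype only, ignoring concrete data.
    return (getattr(x, "shape", ()), str(getattr(x, "dtype", type(x))))

@functools.lru_cache(maxsize=None)
def compiled(fun, avals):
    # Stand-in for tracing + XLA compilation; runs once per cache key.
    print(f"compiling {fun.__name__} for {avals}")
    return fun  # stand-in for "an executable XLA computation"

def toy_jit(fun):
    def wrapped(*args):
        avals = tuple(abstractify(a) for a in args)
        return compiled(fun, avals)(*args)
    return wrapped

square = toy_jit(lambda x: x * x)
square(3.0)  # prints a "compiling ..." line
square(4.0)  # same abstract signature: cache hit, no print
```

The point of the toy is that the cache key never looks inside the function, which is exactly why closures are the hard part in the real system.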
We can add a `print(fun)` to the top of the memoized `xla_callable` function to see why we're getting cache misses running your script. If we run just this part, we see this (ignoring the stuff that prints out due to `device_put`):

So the second time we're seeing `f`, the transform context is pretty different. It's tricky to unpack the details, but the high-level issue is that the second time we trace `f` we don't know if it closed over values that are traced by `g`. If it did, we need to generate different code (because, in effect, the computation carried out by `f` has more inputs the second time), even though from `f`'s point of view those are just constants. This is basically like lambda lifting.

In this case, there's special structure here:
`f` is a top-level function in the original Python source text, and in particular doesn't close over any values that could be traced, so we're safe from this closure issue. Maybe we could detect this special structure (by checking the Python runtime function object and noticing it has an empty closure?) and get a cache hit here.

But in general, when a function has a non-empty closure, we can't tell whether that's a benign closure (with no hidden traces) or whether that closure contains other tracers (maybe very indirectly, buried in arbitrary Python objects, including closed-over Python function objects) until we actually run the function. And at the point where we call and memoize `xla_callable`, we haven't actually run the function yet, so we don't know if we're safe from nested tracers, and we need to be defensive.

I'm inclined to err on the side of simplicity and not try to detect this special closure-free structure until we have a use case that needs it. (However, it's possible that @jonasrauber has already articulated possible use cases in other issues, and I just haven't grokked them yet.)
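The "check for an empty closure" idea can be expressed directly against Python's function object — a sketch of the check, not anything JAX does today:

```python
def has_empty_closure(fun):
    # A plain Python function records closed-over cells in __closure__;
    # a top-level function like the issue's f has none (it is None).
    return getattr(fun, "__closure__", None) is None

def top_level(x):
    return x * x

def make_closure(c):
    def inner(x):
        return x * c  # closes over c
    return inner

print(has_empty_closure(top_level))        # True
print(has_empty_closure(make_closure(2)))  # False
```

Note this only inspects the closure cells; values could in principle still be reachable through globals or other Python objects, which is part of why the conservative approach is simpler.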
@dougalm did I get that right? Should we consider special handling of empty-closure functions, which might mean a special `bind_call` that is promised there are no traces in the closure? (There's a related TODO in ad.py in `call_transpose`.)
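To make the closure hazard concrete, here is a small example (names are illustrative) where an inner jitted function genuinely closes over a tracer from an enclosing trace, so its compiled code really does need an extra input compared with tracing it at top level:

```python
import jax
import jax.numpy as np

def g(x):
    # While g is being traced, x is a tracer; inner closes over it,
    # so x is a hidden input smuggled into inner via the closure.
    def inner(y):
        return y + x
    return jax.jit(inner)(np.ones(3))

# At top level the same closure would capture a plain array instead,
# so the two situations call for different compiled code for inner.
print(jax.jit(g)(np.arange(3.0)))  # [1. 2. 3.]
```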