
jax.jit recompiles nested jitted functions

See original GitHub issue

This isn’t really a bug, more of a question, I guess, but who knows.

Why is f being compiled twice in this example:

import jax
import jax.numpy as np   # jax.numpy
import numpy as onp      # "original" numpy

x = onp.random.rand(1024).astype(onp.float32)
x = jax.device_put(x)

def f(x):
    print('f x', x)
    return np.square(x)

def g(f, x):
    print('g f', f)
    print('g x', x)
    return f(x)

def h(f, x):
    print('h f', f)
    print('h x', x)
    return f(x)

f = jax.jit(f)
g = jax.jit(g, static_argnums=(0,))
h = jax.jit(h, static_argnums=(0,))

f(x)  # will trigger compilation of f
f(x)  # reuse cache
print('')
g(f, x)  # will trigger compilation of g and another compilation of f
g(f, x)  # reuse cache
print('')
h(f, x)  # will trigger compilation of h, but uses cached f
h(f, x)  # reuse cache
Output:

f x Traced<ShapedArray(float32[1024]):JaxprTrace(level=-1/1)>

g f <function jit.<locals>.f_jitted at 0x7faa5bb22ae8>
g x Traced<ShapedArray(float32[1024]):JaxprTrace(level=-1/1)>
f x Traced<ShapedArray(float32[1024]):JaxprTrace(level=-1/2)>

h f <function jit.<locals>.f_jitted at 0x7faa5bb22ae8>
h x Traced<ShapedArray(float32[1024]):JaxprTrace(level=-1/1)>

I guess it has something to do with the level, but I don’t really get why a jitted function is recompiled once it gets called from within another jitted function.

Issue Analytics

  • State: closed
  • Created: 5 years ago
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

18 reactions
jonasrauber commented, Feb 20, 2019

Following up on this, now that I’ve worked a bit more with it: calling jitted functions from within jitted functions does not seem to be a good idea. Not jitting the inner function explicitly (it will still be jitted when the whole outer function is jitted) seems to improve performance. Is this a general rule one can keep in mind (i.e., for maximum performance, put jit only on the outermost function), or are there cases where jitting functions inside other jitted functions is advantageous?
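As a minimal sketch of the rule of thumb above (using current `jax.numpy` conventions; the function names here are made up for illustration), an un-jitted inner function is simply traced inline when the outer function is jitted, so everything compiles into a single XLA computation:

```python
import jax
import jax.numpy as jnp

def inner(x):
    # not jitted explicitly; it is traced inline when `outer` is compiled
    return jnp.square(x)

@jax.jit
def outer(x):
    # the whole body, including the call to `inner`, becomes one XLA computation
    return inner(x) + 1.0

x = jnp.arange(4.0)
print(outer(x))
```

Jitting `inner` separately would instead introduce a nested jit boundary, which is the situation discussed in this issue.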

1 reaction
mattjj commented, Feb 6, 2019

but I don’t really get why a jitted function is recompiled once it gets called from within another jitted function.

Our trace-caching logic is pretty simple: it’s just a @memoize decorator on the function that takes a wrapped Python callable fun and a set of abstract arguments and returns an executable XLA computation. The wrapping of fun just records what transformations have been applied to the underlying Python callable (and any auxiliary information that they need to smuggle out after the function has been called, like the tuple/list/dict tree structure of the output), and so that @memoize decorator is taking that transformation stack into account too.
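The caching scheme described here can be modeled very roughly as follows (an illustrative toy, not JAX’s actual implementation; all names are invented). The key point is that the cache key covers the transformation stack as well as the callable and the abstract argument signature, so the same function under a different stack is a cache miss:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def compile_for(fun, transform_stack, abstract_args):
    # Stand-in for tracing + XLA compilation; in the toy model a "compilation"
    # happens exactly once per (callable, transform stack, abstract signature).
    print(f"compiling {fun.__name__} under {transform_stack}")
    return f"<executable for {fun.__name__}{abstract_args}>"

def square(x):
    return x * x

sig = (("float32", (1024,)),)
compile_for(square, ("jit",), sig)        # cache miss: compiles
compile_for(square, ("jit",), sig)        # cache hit: reuses the executable
compile_for(square, ("jit", "jit"), sig)  # different transform stack: compiles again
```

The last call mirrors what happens in the original script: `f` called inside `g` arrives with extra transformations wrapped around it, so the earlier cache entry doesn’t apply.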

What’s complicated is the transformations we need to set up to guard against other traces’ tracers hiding in function closures.

We can add a print(fun) to the top of the memoized xla_callable function to see why we’re getting cache misses running your script. If we run just this part:

f(x)
g(f, x)

we see this (ignoring the stuff that prints out due to device_put):

Wrapped function:
0   : flatten_fun   ((*,),)
1   : process_env_traces   (xla_call, -1)
2   : pytree_fun_to_jaxtupletree_fun   ((*,),)
3   : argnums_partial_   ((0,), (None,))
Core: f

('f x', Traced<ShapedArray(float32[1024]):JaxprTrace(level=-1/1)>)
Wrapped function:
0   : flatten_fun   ((*,),)
1   : process_env_traces   (xla_call, -1)
2   : pytree_fun_to_jaxtupletree_fun   ((*,),)
3   : argnums_partial_   ((1,), (<jax.util.WrapHashably object at 0x7f9d54442990>, None))
Core: g

('g f', <function jit(f) at 0x7f9d54457848>)
('g x', Traced<ShapedArray(float32[1024]):JaxprTrace(level=-1/1)>)
Wrapped function:
0   : flatten_fun   ((JTupleTreeDef(child_specs=()),),)
1   : process_env_traces   (xla_call, -2)
2   : partial_eval_wrapper   ((ShapedArray(float32[1024]),),)
3   : trace_to_subjaxpr   (MasterTrace(-1,JaxprTrace),)
4   : process_env_traces   (xla_call, -1)
5   : pytree_fun_to_jaxtupletree_fun   ((*,),)
6   : argnums_partial_   ((0,), (None,))
Core: f

So the second time we’re seeing f, the transform context is pretty different. It’s tricky to unpack the details, but the high-level issue is that the second time we trace f we don’t know if it closed over values that are traced by g. If it did, we need to generate different code (because, in effect, the computation carried out by f has more inputs the second time), even though from f’s point of view those are just constants. This is basically like lambda lifting.
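The closure hazard described above can be sketched like this (a hypothetical example, not from the original thread): the inner function closes over a value that is a tracer while the outer function is being traced, so even though that value looks like a constant from the inner function’s point of view, the inner computation effectively has an extra input:

```python
import jax
import jax.numpy as jnp

def g(x):
    y = x * 2.0  # y is a tracer while g is being traced

    def f(z):
        # f closes over y; to f this is "just a constant", but its value
        # depends on g's input, so f's computation gains a hidden input
        # and cannot be compiled once and for all (cf. lambda lifting).
        return z + y

    return f(x)

# The jaxpr for g shows y flowing into the body of f:
print(jax.make_jaxpr(g)(1.0))
```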

In this case, there’s special structure here: f is a top-level function in the original Python source text, and in particular doesn’t close over any values that could be traced, so we’re safe from this closure issue. Maybe we could detect this special structure (by checking the Python runtime function object and noticing it has an empty closure?) and get a cache hit here.
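Such an empty-closure check could in principle be written like this (a sketch of the idea only; this is not an actual JAX mechanism). A plain Python function that captures no enclosing variables has `__closure__` set to `None`:

```python
def has_empty_closure(fun):
    # True for top-level functions and any function that captures nothing
    # from an enclosing scope.
    return getattr(fun, "__closure__", None) is None

def top_level(x):
    return x + 1

def make_closure(c):
    def inner(x):
        return x + c  # captures c from the enclosing scope
    return inner

print(has_empty_closure(top_level))        # no captured variables
print(has_empty_closure(make_closure(1)))  # captures c
```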

But in general, when a function has a non-empty closure, we can’t tell whether that’s a benign closure (with no hidden traces) or whether that closure contains other tracers (maybe very indirectly, buried in arbitrary Python objects, including closed-over Python function objects) until we actually run the function. And at the point where we call and memoize xla_callable, we haven’t actually run the function yet, so we don’t know if we’re safe from nested tracers, and we need to be defensive.

I’m inclined to err on the side of simplicity and not try to detect this special closure-free structure until we have a use case that needs it. (However, it’s possible that @jonasrauber has already articulated possible use cases in other issues, and I just haven’t grokked them yet.)

@dougalm did I get that right? Should we consider special handling of empty-closure functions, which might mean a special bind_call that is promised there are no traces in the closure? (There’s a related todo in ad.py in call_transpose.)
