inner `jit` functions are re-traced (and re-compiled)
It seems that `jit` decorators are ignored for nested functions. Consider the following example (simplified from an actual program, where `inner_fn` is slightly more complex):
```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, inline=True)
def inner_fn(state):
    print("entering inner_fn now")
    return 2*state

@jax.jit
def outer_fn(x):
    print("entering outer_fn now")
    old_x = x
    for _ in range(10):
        x = inner_fn(x)
        x = x + old_x
    return x

state = jnp.arange(5, dtype=jnp.uint32)
outer_fn(state).block_until_ready()
```
Expected behavior: As both functions are decorated with `jit`, `inner_fn` should be compiled once and the result re-used by the plain `for` loop inside of `outer_fn`. Consequently, the output would show only one emission of “entering inner_fn now”.

Actual behavior: The `jit` decorator for `inner_fn` seems to be ignored: “entering inner_fn now” is printed 10 times, and setting the env var `JAX_LOG_COMPILES=1` only prints one line for `Compiling outer_fn`. In short, we observe the same behavior as if the `jit` decorator were absent for `inner_fn`.
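The same compile logging can also be switched on from Python instead of via the environment variable; a small sketch:

```python
import jax

# Programmatic equivalent of setting JAX_LOG_COMPILES=1: log a line for
# every XLA compilation that actually happens.
jax.config.update("jax_log_compiles", True)
```

The `print` calls in the example only run while a function is being traced, so counting their output counts traces, while the compile log counts actual compilations.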
Why a bug?: The current behavior negates the caching done for jitted functions, resulting in repeated tracing of the function with identical tracers. This inflates compilation times and can easily surprise the user.
Other considerations:

- Using `fori_loop`: Using a `fori_loop` results in `inner_fn` being visited only once, so that mitigates the issue for this particular example (see the sketch after this list). However, we were actually having trouble with it, as it prevents unrolling the loop, and our `inner_fn` was small enough that we were seeing performance hits from CUDA kernel launches for `inner_fn` in this case.
- No `inline=True` for the `jit` decorator of `inner_fn`: In this case the compilation will produce 10 separate sub-jaxprs for `inner_fn` that are not inlined but all invoked via `xla_call` (when inspecting this via `jax.make_jaxpr`). We see that the `jit` decorator is therefore not completely without effect, but this is arguably the worst case (separate `xla_call` overhead as in the `fori_loop` case AND multiple compilations of `inner_fn`).
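For reference, a minimal sketch of the `fori_loop` variant and the `jax.make_jaxpr` inspection mentioned above (the exact rewrite of `outer_fn` here, named `outer_fn_fori`, is an assumption based on the description, not code from the report):

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, inline=True)
def inner_fn(state):
    return 2 * state

# fori_loop variant of outer_fn: inner_fn is only traced once because the
# loop body is traced a single time, but the loop is no longer unrolled.
@jax.jit
def outer_fn_fori(x):
    old_x = x
    return jax.lax.fori_loop(0, 10, lambda i, s: inner_fn(s) + old_x, x)

state = jnp.arange(5, dtype=jnp.uint32)
outer_fn_fori(state).block_until_ready()

# Inspecting the jaxpr shows whether inner_fn stays a separate call
# (the name of the call primitive may differ between JAX versions) or is inlined.
print(jax.make_jaxpr(outer_fn_fori)(state))
```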
Please:
- Check for duplicate issues.
- Provide a complete example of how to reproduce the bug, wrapped in triple backticks.
- If applicable, include full error messages/tracebacks.
Top GitHub Comments
A concrete next step on our end is to add thorough logging so that these times are super easy to inspect.
Alright, I tried to boil it down to the essence in the following. So basically I have this piece of code which repeatedly applies a vector-valued function (`col_fn`) to all columns and (some form of) diagonals of a (4x4) matrix (this is realised by `inner_fn`, with the repeated application as a loop in `outer_fn`). My basic approach for `inner_fn` was to first `vmap` the function `col_fn` over columns, then reorganise the matrix so that the diagonals I’m interested in are put into columns, and apply the vmapped `col_fn` again. This remapping is what seems to be causing problems for me. Code (`col_fn` is just a stand-in right now, with a loop to adjust its complexity in terms of the simple operations it applies):
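A minimal sketch of the structure described above (not the original snippet; the concrete `col_fn`, the `reorganise_diagonals` helper with its wrapped-diagonal indexing, and the iteration counts are assumptions for illustration):

```python
import jax
import jax.numpy as jnp

N = 4  # working on a 4x4 matrix

def col_fn(v):
    # Stand-in for the real per-column function; the loop length controls
    # how many simple operations it applies.
    for _ in range(10):
        v = v * 1.5 + 0.5
    return v

# vmap col_fn over the columns of the matrix.
col_apply = jax.vmap(col_fn, in_axes=1, out_axes=1)

def reorganise_diagonals(m):
    # One possible reorganisation: result[i, d] = m[i, (i + d) % N], so that
    # column d of the result holds the d-th (wrapped) diagonal of m.
    rows = jnp.arange(N)[:, None]
    cols = (rows + jnp.arange(N)[None, :]) % N
    return m[rows, cols]

@jax.jit
def inner_fn(m):
    m = col_apply(m)                         # apply to the columns
    m = col_apply(reorganise_diagonals(m))   # apply to the diagonals-as-columns
    return m

@jax.jit
def outer_fn(m):
    for _ in range(7):
        m = inner_fn(m)
    return m

m0 = jnp.arange(N * N, dtype=jnp.float32).reshape(N, N)
outer_fn(m0).block_until_ready()
```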
As you can see, I tried different ways of reorganising the matrix, and they result in different fusing outcomes, which I investigated by looking at the outputs produced when setting the `--xla_dump_hlo_as_text` XLA flag:

- […] `gather` operations in HLO (i.e., a fusion will always end with a `gather` op). This roughly merges 2 iterations of `inner_fn`, so for 7 iterations I end up with 4 fusions: the first covering the first 1.5 iterations (stopping at the reorganising in the 2nd iteration), the second fusion covering the second half of the 2nd iteration, the entire 3rd and the first half of the 4th, and so on. This seems invariant of the number of operations in `col_fn` or the number of iterations of `inner_fn`.
- […] `inner_fn`, depending entirely on the complexity/number of iterations in `col_fn`: in my tests, I get a single fusion for fewer than 13 iterations in `col_fn` and one fusion per `inner_fn` invocation for more. The splits again occur at the `gather` ops, so the first fusion covers the first half of the first invocation of `inner_fn`, and each following fusion covers a second half and the following first half of the next iteration.
- […] `gather` ops emitted and the result is a single fusion in all cases, no matter the complexity of `col_fn` or the number of iterations of `inner_fn`.

Further: In reality, `col_fn` is a bit more involved than shown above and requires interdependent updates to the components of the received vector. Something akin to this (with some randomly chosen elementary operations):
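A rough illustration of that kind of `col_fn` (the specific elementary operations here are placeholders, not the original code):

```python
import jax.numpy as jnp

def col_fn(v):
    # Interdependent updates to the components of the received (length-4)
    # vector, with arbitrarily chosen elementary operations as placeholders.
    a = v[0] + v[1] * v[3]
    b = v[1] - a * v[2]
    c = v[2] * b + v[0]
    d = v[3] + c - a
    # Producing the output vector needs a concatenation-style op ...
    return jnp.stack([a, b, c, d])
    # ... or, alternatively, an index update:
    # return v.at[0].set(a).at[1].set(b).at[2].set(c).at[3].set(d)
```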
I haven’t been able to do this without at some point using a concatenating operation to produce the output vector (or using an `index_update`), which results in even more splits of the fusions in the output HLO dump.
Other things I tried:

- `col_fn` receives the full matrix and a list of indices instead of a vector slice of the matrix, to avoid having to reorganise the matrix. However, this merely moves the `gather` from the reorganising in `inner_fn` to the then-required index lookup at the start of `col_fn` … (a sketch follows this list)
- […] `reduce` ops that replace the `gather` ops (+ additional `add` kernels for each iteration that are used by the `reduce` ops).
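For concreteness, a sketch of the first variant in the list above (the indexing scheme and names here are assumptions for illustration):

```python
import jax
import jax.numpy as jnp

N = 4

def col_fn(m, rows, cols):
    # Receive the full matrix plus index vectors and gather the "column"
    # inside col_fn; the gather has merely moved here from inner_fn.
    v = m[rows, cols]
    for _ in range(10):
        v = v * 1.5 + 0.5
    return v

# vmap over the per-column index vectors while broadcasting the matrix.
col_apply = jax.vmap(col_fn, in_axes=(None, 1, 1), out_axes=1)

rows = jnp.broadcast_to(jnp.arange(N)[:, None], (N, N))        # row index per entry
cols = (jnp.arange(N)[:, None] + jnp.arange(N)[None, :]) % N   # wrapped-diagonal column indices
m0 = jnp.arange(N * N, dtype=jnp.float32).reshape(N, N)
out = col_apply(m0, rows, cols)  # shape (N, N)
```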