Using `jax.jit` inside a function decorated by `jax.checkpoint` causes recompilation every time

Using a jitted function inside a function decorated by jax.checkpoint causes many extra compilations, even when the arguments keep the same shape. Computing the gradient of such a function leaks memory in the long run, since all of the compiled jitted functions appear to be retained. This shows up as a high memory footprint in backend_compile, which does not occur when checkpointing is disabled.
A self-contained example:
import jax
import jax.numpy as jnp

@jax.jit
def f(a):
    return jnp.sum(a)

@jax.checkpoint
def g(a):
    return f(a)

arr = jnp.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], dtype=float)
g_v_and_grad = jax.value_and_grad(g)
for i in range(3):
    working_arr = arr + i
    print(g_v_and_grad(working_arr))
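As a possible workaround (a sketch, not an official fix from this issue), placing jax.jit on the outside of the checkpointed function rather than inside it appears to avoid the repeated compilations, because the outer jit caches on argument shapes and dtypes:

```python
import jax
import jax.numpy as jnp

def f(a):
    return jnp.sum(a)

# Hypothetical workaround: jit wraps the whole value-and-grad of the
# checkpointed function, so the computation is compiled once per input
# shape/dtype and cached on later calls.
g_v_and_grad = jax.jit(jax.value_and_grad(jax.checkpoint(f)))

arr = jnp.array([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], dtype=float)
for i in range(3):
    val, grad = g_v_and_grad(arr + i)  # compiles only on the first call
    print(val)
```

With JAX_LOG_COMPILES=1, only the first iteration should log a compilation under this arrangement.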
Running the script with JAX_LOG_COMPILES=1 set in the environment, one observes:
WARNING:absl:Finished tracing + transforming prim_fun for jit in 0.0002334117889404297 sec
WARNING:absl:Finished tracing + transforming fn for jit in 0.0003993511199951172 sec
WARNING:absl:Compiling fn (139703279463296 for args (ShapedArray(float32[2,5]), ShapedArray(int32[], weak_type=True)).
WARNING:absl:Finished XLA compilation of fn in 0.04700160026550293 sec
WARNING:absl:Finished tracing + transforming f for jit in 0.0010411739349365234 sec
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.00015473365783691406 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209762752 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.04526233673095703 sec
WARNING:absl:Finished tracing + transforming prim_fun for jit in 0.00016546249389648438 sec
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.00014591217041015625 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209798976 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.00750732421875 sec
WARNING:absl:Finished tracing + transforming backward_pass for jit in 0.0011491775512695312 sec
WARNING:absl:Compiling backward_pass (139703209802560 for args (ShapedArray(float32[]),).
WARNING:absl:Finished XLA compilation of transpose(jvp(f)) in 0.041948556900024414 sec
(DeviceArray(55., dtype=float32), DeviceArray([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]], dtype=float32))
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.00014543533325195312 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209800384 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.007508516311645508 sec
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.0001461505889892578 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209863232 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.007668972015380859 sec
WARNING:absl:Finished tracing + transforming backward_pass for jit in 0.0005974769592285156 sec
WARNING:absl:Compiling backward_pass (139703209362624 for args (ShapedArray(float32[]),).
WARNING:absl:Finished XLA compilation of transpose(jvp(f)) in 0.005425214767456055 sec
(DeviceArray(65., dtype=float32), DeviceArray([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]], dtype=float32))
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.00014638900756835938 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209350720 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.007513523101806641 sec
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.00015473365783691406 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209372160 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.007587909698486328 sec
WARNING:absl:Finished tracing + transforming backward_pass for jit in 0.0005350112915039062 sec
WARNING:absl:Compiling backward_pass (139703209370048 for args (ShapedArray(float32[]),).
WARNING:absl:Finished XLA compilation of transpose(jvp(f)) in 0.0054433345794677734 sec
(DeviceArray(75., dtype=float32), DeviceArray([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]], dtype=float32))
Commenting out the checkpoint decorator leads to the desired behavior:
WARNING:absl:Finished tracing + transforming prim_fun for jit in 0.0002498626708984375 sec
WARNING:absl:Finished tracing + transforming fn for jit in 0.00040721893310546875 sec
WARNING:absl:Compiling fn (140693235752000 for args (ShapedArray(float32[2,5]), ShapedArray(int32[], weak_type=True)).
WARNING:absl:Finished XLA compilation of fn in 0.04748940467834473 sec
WARNING:absl:Finished tracing + transforming f for jit in 0.0010097026824951172 sec
WARNING:absl:Compiling f (140692730754112 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.04457998275756836 sec
WARNING:absl:Finished tracing + transforming prim_fun for jit in 0.0001583099365234375 sec
WARNING:absl:Finished tracing + transforming backward_pass for jit in 0.0004944801330566406 sec
WARNING:absl:Compiling backward_pass (140692730730304 for args (ShapedArray(float32[]),).
WARNING:absl:Finished XLA compilation of transpose(jvp(f)) in 0.041858673095703125 sec
(DeviceArray(55., dtype=float32), DeviceArray([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]], dtype=float32))
(DeviceArray(65., dtype=float32), DeviceArray([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]], dtype=float32))
(DeviceArray(75., dtype=float32), DeviceArray([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]], dtype=float32))
Issue Analytics
- Created 2 years ago
- Comments: 8 (4 by maintainers)
This was a bit tricky! There is a problem with grad-of-remat-of-jit causing jit cache misses. Two changes are sufficient to fix it, and I believe both are necessary as well (though see below for alternatives). One of them is memoizing jaxpr tracing (the trace_to_subjaxpr_dynamic_memoized function; I've been meaning to land this for a while). That memoization is necessary because when round-tripping through a jaxpr (i.e. basically doing eval-jaxpr-of-make-jaxpr, as we often do, including when we linearize) we need to make sure we get cache hits. But we were constructing new opaque callables (with hash/equality defined by object id) every time, meaning we never did get cache hits.
In particular, consider this repro involving jax.linearize: running env JAX_LOG_COMPILES=1 python linearize_repro.py $N would show a number of compilations that scales linearly with N. (I piped stderr into grep 'XLA compilation' | wc -l like a pro.) After #10034 the number of compiles becomes constant in N.
However, while that fixed the issue for linearize, it wasn't quite enough for grad (which is basically linearize plus a transposition step). For that, we had to improve the remat transpose rule to support caching. That's what #10037 does.
(TODO: explain that the new version of checkpoint in ad_checkpoint.checkpoint also needed separate work. TODO: talk about the new tooling I'm going to add to explain automatically why recompiles happen.)
Never mind, commit df1c478ec52fb75ce88c06ab0133d9f5263c6767 already fixed the problem for me; I was just on an outdated checkout during my test.