Using `jax.jit` inside a function decorated by `jax.checkpoint` causes recompilation every time

Using a jitted function inside a function decorated by jax.checkpoint causes many extra compilations, even when the arguments keep the same shape. Computing the gradient of such a function therefore leaks memory in the long run, since all the compiled versions of the jitted function appear to be kept in memory. This shows up as a large memory footprint in backend_compile, which does not appear when checkpointing is disabled.

A self-contained example:

import jax
import jax.numpy as jnp

@jax.jit
def f(a):
    return jnp.sum(a)

@jax.checkpoint
def g(a):
    return f(a)

arr = jnp.array([[1,2,3,4,5],[6,7,8,9,10]], dtype=float)

g_v_and_grad = jax.value_and_grad(g)

for i in range(3):
    working_arr = arr + i
    print(g_v_and_grad(working_arr))

Running the script with the environment variable JAX_LOG_COMPILES=1 set, one can observe:

WARNING:absl:Finished tracing + transforming prim_fun for jit in 0.0002334117889404297 sec
WARNING:absl:Finished tracing + transforming fn for jit in 0.0003993511199951172 sec
WARNING:absl:Compiling fn (139703279463296 for args (ShapedArray(float32[2,5]), ShapedArray(int32[], weak_type=True)).
WARNING:absl:Finished XLA compilation of fn in 0.04700160026550293 sec
WARNING:absl:Finished tracing + transforming f for jit in 0.0010411739349365234 sec
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.00015473365783691406 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209762752 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.04526233673095703 sec
WARNING:absl:Finished tracing + transforming prim_fun for jit in 0.00016546249389648438 sec
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.00014591217041015625 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209798976 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.00750732421875 sec
WARNING:absl:Finished tracing + transforming backward_pass for jit in 0.0011491775512695312 sec
WARNING:absl:Compiling backward_pass (139703209802560 for args (ShapedArray(float32[]),).
WARNING:absl:Finished XLA compilation of transpose(jvp(f)) in 0.041948556900024414 sec
(DeviceArray(55., dtype=float32), DeviceArray([[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]], dtype=float32))
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.00014543533325195312 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209800384 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.007508516311645508 sec
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.0001461505889892578 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209863232 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.007668972015380859 sec
WARNING:absl:Finished tracing + transforming backward_pass for jit in 0.0005974769592285156 sec
WARNING:absl:Compiling backward_pass (139703209362624 for args (ShapedArray(float32[]),).
WARNING:absl:Finished XLA compilation of transpose(jvp(f)) in 0.005425214767456055 sec
(DeviceArray(65., dtype=float32), DeviceArray([[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]], dtype=float32))
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.00014638900756835938 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209350720 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.007513523101806641 sec
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.00015473365783691406 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209372160 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.007587909698486328 sec
WARNING:absl:Finished tracing + transforming backward_pass for jit in 0.0005350112915039062 sec
WARNING:absl:Compiling backward_pass (139703209370048 for args (ShapedArray(float32[]),).
WARNING:absl:Finished XLA compilation of transpose(jvp(f)) in 0.0054433345794677734 sec
(DeviceArray(75., dtype=float32), DeviceArray([[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]], dtype=float32))

Commenting out the checkpoint decorator gives the expected behavior: everything is compiled during the first call, and later calls trigger no compilations:

WARNING:absl:Finished tracing + transforming prim_fun for jit in 0.0002498626708984375 sec
WARNING:absl:Finished tracing + transforming fn for jit in 0.00040721893310546875 sec
WARNING:absl:Compiling fn (140693235752000 for args (ShapedArray(float32[2,5]), ShapedArray(int32[], weak_type=True)).
WARNING:absl:Finished XLA compilation of fn in 0.04748940467834473 sec
WARNING:absl:Finished tracing + transforming f for jit in 0.0010097026824951172 sec
WARNING:absl:Compiling f (140692730754112 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.04457998275756836 sec
WARNING:absl:Finished tracing + transforming prim_fun for jit in 0.0001583099365234375 sec
WARNING:absl:Finished tracing + transforming backward_pass for jit in 0.0004944801330566406 sec
WARNING:absl:Compiling backward_pass (140692730730304 for args (ShapedArray(float32[]),).
WARNING:absl:Finished XLA compilation of transpose(jvp(f)) in 0.041858673095703125 sec
(DeviceArray(55., dtype=float32), DeviceArray([[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]], dtype=float32))
(DeviceArray(65., dtype=float32), DeviceArray([[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]], dtype=float32))
(DeviceArray(75., dtype=float32), DeviceArray([[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]], dtype=float32))
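
JAX_LOG_COMPILES is one way to spot cache misses. A rough, log-free complement is to count how often JAX traces a function, relying on the fact that Python side effects inside a jitted function run only while JAX is tracing it. The sketch below illustrates that idea. The names trace_count and body are hypothetical and do not come from the original report.

```python
import jax
import jax.numpy as jnp

# Hypothetical counter for this sketch: how many times JAX has traced `body`.
trace_count = 0


def body(a):
    global trace_count
    trace_count += 1  # Python side effect: executes only while JAX traces `body`
    return jnp.sum(a)


f = jax.jit(body)

# Same shape and dtype on every call: traced once, then served from the jit cache.
for i in range(3):
    f(jnp.ones((2, 5)) + i)
traces_same_shape = trace_count

# A new input shape is a cache miss, so `body` is traced again.
f(jnp.ones((3, 5)))
traces_after_new_shape = trace_count

# A fresh jitted wrapper per call is a new cache key each time, so every call retraces.
trace_count = 0
for i in range(3):
    jax.jit(lambda a: body(a))(jnp.ones((2, 5)))
traces_fresh_wrapper = trace_count
```

One caveat applies here. In the checkpointed logs above, f itself is traced only on the first call, and the repeated traces are of "<unnamed wrapped function>" wrappers created inside JAX. A counter in the user's own function body would therefore most likely not have flagged this particular bug, which is why the compile log was the more useful signal.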

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 8 (4 by maintainers)

Top GitHub Comments

mattjj commented, Mar 25, 2022 (1 reaction)

This was a bit tricky! There is a problem with grad-of-remat-of-jit causing jit cache misses. These two changes are sufficient to fix it, and I believe both are necessary as well (though see below for alternatives):

  1. #10037 ~#9181 (specifically the trace_to_subjaxpr_dynamic_memoized function, been meaning to land this for a while…)~
  2. #10034

The latter is necessary because when round-tripping through a jaxpr (i.e. basically doing eval-jaxpr-of-make-jaxpr, as we often do, including when we linearize) we need to make sure we get cache hits. But we were constructing new opaque callables (with hash/equality defined by object id) every time, meaning we never did get cache hits.
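
The failure mode described here can be reproduced without JAX. That failure mode is a cache keyed on callables whose hash and equality fall back to object identity. The code below is a toy model and not JAX's actual caching code. compile_for stands in for a compilation cache, and OpaqueCallable, KeyedCallable, shared and count_compiles are hypothetical names chosen for illustration.

```python
import functools


@functools.lru_cache(maxsize=None)
def compile_for(fn, shape):
    """Toy compilation cache keyed on (callable, input shape); a miss means 'compile'."""
    return ("executable", shape)


class OpaqueCallable:
    """Wraps a function without __eq__/__hash__, so keys compare by object identity."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args):
        return self.fn(*args)


class KeyedCallable(OpaqueCallable):
    """Same wrapper, but equal to any other wrapper around the same function."""

    def __eq__(self, other):
        return isinstance(other, KeyedCallable) and self.fn is other.fn

    def __hash__(self):
        return hash(self.fn)


def double(x):
    return 2 * x


def count_compiles(make_callable, n_calls=3):
    """Look up the cache n_calls times with fresh keys and return the miss count."""
    compile_for.cache_clear()  # also resets hit/miss statistics
    for _ in range(n_calls):
        compile_for(make_callable(), (2, 5))
    return compile_for.cache_info().misses


# New opaque wrapper on every call: every lookup misses.
fresh_opaque = count_compiles(lambda: OpaqueCallable(double))
# The cache holds (and keeps alive) one stale entry per wrapper.
entries_kept = compile_for.cache_info().currsize

# Reusing a single wrapper object: one miss, then hits.
shared = OpaqueCallable(double)
reused_opaque = count_compiles(lambda: shared)

# Fresh wrappers that compare equal by the wrapped function: one miss, then hits.
fresh_keyed = count_compiles(lambda: KeyedCallable(double))
```

In this model, a fresh opaque wrapper on every call misses the cache each time and leaves every stale entry behind. That is the toy analogue of the growing backend_compile footprint in the report. Either reusing one wrapper or defining equality on what the wrapper contains brings the count back to a single compile.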

In particular, consider this repro involving jax.linearize:

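# linearize_repro.py
# Usage: env JAX_LOG_COMPILES=1 python linearize_repro.py N
import sys
import jax

# checkpoint (remat) wrapped around jit; despite the name, this doubles its input
identity = jax.checkpoint(jax.jit(lambda x: 2 * x))

_, f_lin = jax.linearize(identity, 1.)
# Call the linearized function N times; ideally only the first call compiles
for _ in range(int(sys.argv[1])):
  f_lin(1.)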

Running env JAX_LOG_COMPILES=1 python linearize_repro.py $N would show a number of compilations that scales linearly with N. (I piped stderr into grep 'XLA compilation' | wc -l like a pro.)

After #10034 the number of compiles becomes constant with N.
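
The grep 'XLA compilation' | wc -l pipeline simply counts log lines that mention an XLA compilation. A minimal Python equivalent is shown below. It is applied to the compile-log lines the reporter posted for the second call of the checkpointed example. count_xla_compilations is a hypothetical helper name.

```python
def count_xla_compilations(log_text: str) -> int:
    """Count lines mentioning 'XLA compilation', like grep 'XLA compilation' | wc -l."""
    return sum("XLA compilation" in line for line in log_text.splitlines())


# Compile-log lines printed during the second call of the checkpointed example (from the issue).
second_call_with_checkpoint = """\
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.00014543533325195312 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209800384 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.007508516311645508 sec
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.0001461505889892578 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209863232 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.007668972015380859 sec
WARNING:absl:Finished tracing + transforming backward_pass for jit in 0.0005974769592285156 sec
WARNING:absl:Compiling backward_pass (139703209362624 for args (ShapedArray(float32[]),).
WARNING:absl:Finished XLA compilation of transpose(jvp(f)) in 0.005425214767456055 sec
"""

# Without jax.checkpoint, the second call printed no compile-log lines at all.
second_call_without_checkpoint = ""

recompiles_with = count_xla_compilations(second_call_with_checkpoint)
recompiles_without = count_xla_compilations(second_call_without_checkpoint)
```

With jax.checkpoint, each later call of the original example recompiles three programs: two jvp(f) and one transpose(jvp(f)). Without it, later calls compile nothing.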

However, while that fixed the issue for linearize, it wasn’t quite enough for grad (which is basically linearize plus a transposition step). For that, we had to improve the remat transpose rule to support caching. That’s what #10037 does.
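
The decomposition of grad into linearize plus a transposition step can be seen with public JAX APIs. Below is a sketch of the idea using jax.linearize and jax.linear_transpose on a simple function chosen for illustration. It does not exercise the remat caching fix itself.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(x ** 2)


x = jnp.arange(3.0)

# Step 1 (linearize): evaluate f at x and get its JVP as a linear function of the tangent.
y, f_jvp = jax.linearize(f, x)

# Step 2 (transpose): turn the linear JVP map into its transpose, i.e. the VJP.
# `x` only tells linear_transpose the input shape and dtype.
f_vjp = jax.linear_transpose(f_jvp, x)

# Pulling back an output cotangent of 1.0 yields the gradient (a 1-tuple, one per primal).
(x_grad,) = f_vjp(jnp.ones_like(y))
```

This is why, as the comment explains, getting cache hits for linearize alone was not enough for grad: the transposition step also has to hit the cache.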

(TODO explain that the new version of checkpoint in ad_checkpoint.checkpoint also needed separate work)

TODO talk about the new tooling I’m going to add to explain automatically why recompiles happen

JanLuca commented, Apr 6, 2022 (0 reactions)

Never mind: commit df1c478ec52fb75ce88c06ab0133d9f5263c6767 already fixed the problem for me. I was just on an outdated checkout during my test.
