Memory leak when computing the hessian of a jitted function
import jax
import jax.numpy as jnp

f = lambda x: jnp.square(x).mean()
jf = jax.jit(f)
x = jax.random.uniform(jax.random.PRNGKey(0), shape=(8, 4))

# Start monitoring memory with htop
while True:
    try:
        # let this run for a bit, notice memory usage doesn't increase
        y = jax.hessian(f)(x)
    except KeyboardInterrupt:
        # Ctrl-C ends this first loop so the second loop can start
        break

while True:
    # let this run, notice memory usage starts increasing immediately
    # also, the memory usage doesn't appear to stop increasing ever
    y = jax.hessian(jf)(x)
Also, note that the equivalent code with jax.grad (or jax.jacfwd for multidimensional functions) replacing jax.hessian does not cause a memory leak.
As a result of this leak, my Python process ends up getting killed by my OS (due to OOM), so I never see any Python or JAX traces.
Issue Analytics
- State:
- Created 2 years ago
- Comments:10 (8 by maintainers)
Thanks for ~the question~ raising this, and with a clear reproducer!
At first I thought this could be an artifact of asynchronous dispatch. (The jit dispatch path is different from the non-jit dispatch path, so that may play a role in why f and jf behave differently.) But I ran this version, which forces synchronizations, and it seems like memory is still growing without bound (on the CPU backend):
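The snippet that followed was not preserved in this copy of the thread. As a minimal sketch of what a synchronizing variant might look like, the loop below calls block_until_ready() on each result so that asynchronous dispatch cannot queue up work. The helper name run_synchronized and the num_iters bound are hypothetical and added only so the example terminates; the original presumably looped forever.

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.square(x).mean()
jf = jax.jit(f)
x = jax.random.uniform(jax.random.PRNGKey(0), shape=(8, 4))


def run_synchronized(fn, x, num_iters=100):
    """Hypothetical helper: evaluate jax.hessian(fn)(x) repeatedly,
    waiting for every result before starting the next iteration."""
    y = None
    for _ in range(num_iters):
        # block_until_ready() waits until the asynchronously dispatched
        # computation has actually finished on the device.
        y = jax.hessian(fn)(x).block_until_ready()
    return y


# Bounded run for illustration; watch memory in htop while increasing num_iters.
y = run_synchronized(jf, x, num_iters=10)
```

If memory still grows with this loop, pending asynchronous work is unlikely to be the cause, which is what the comment above reports.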
I then tried running the script like this:
and saw that we’re doing lots of recompiles here:
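The exact command and its output are not preserved in this copy. One way to surface recompiles, offered here as an assumption rather than as what was actually run, is JAX's jax_log_compiles configuration option, which makes JAX log a message whenever it compiles a function. The loop is bounded to three iterations purely for illustration.

```python
import jax
import jax.numpy as jnp

# Ask JAX to log a message each time it compiles a function.
# (The same option can usually be set with the JAX_LOG_COMPILES=1
# environment variable; treat that spelling as an assumption here.)
jax.config.update("jax_log_compiles", True)

f = lambda x: jnp.square(x).mean()
jf = jax.jit(f)
x = jax.random.uniform(jax.random.PRNGKey(0), shape=(8, 4))

for _ in range(3):
    # With a working compilation cache, later iterations should reuse
    # earlier compilations; repeated log messages for the same function
    # on every iteration would point to cache misses.
    h = jax.hessian(jf)(x)
```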
Maybe that should’ve been obvious from your clue about grad/jacfwd behaving differently. Somehow hessian is causing compilation cache misses here…

I think it’s an instance of the same bug. Running your code against current HEAD, I only see ‘called’ printed once.
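The remark about 'called' suggests a reproducer whose function prints when it is traced, though that script is not shown in this copy. A minimal sketch under that assumption: Python side effects inside a jitted function run only during tracing, so counting them shows how often the function is retraced. The trace_log list is a hypothetical addition for counting.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical: one entry is appended per trace of f


def f(x):
    # Side effects like these run only while JAX traces the function,
    # not on calls that reuse a cached trace.
    trace_log.append(1)
    print("called")
    return jnp.square(x).mean()


jf = jax.jit(f)
x = jax.random.uniform(jax.random.PRNGKey(0), shape=(8, 4))

for _ in range(5):
    y = jax.hessian(jf)(x)

print("number of traces:", len(trace_log))
```

A trace count that grows with the number of loop iterations would be consistent with the cache misses described above; a count that stays constant would match the behavior reported for HEAD.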