
Memory leak when computing the hessian of a jitted function


import jax
import jax.numpy as jnp

f = lambda x: jnp.square(x).mean()
jf = jax.jit(f)

x = jax.random.uniform(jax.random.PRNGKey(0), shape=(8, 4))

# Start monitoring memory with htop
while True:
    try:
        # let this run for a bit; notice memory usage doesn't increase
        y = jax.hessian(f)(x)
    except KeyboardInterrupt:
        # Ctrl-C exits this loop and moves on to the jitted version below
        break

while True:
    # let this run; notice memory usage starts increasing immediately
    # and never appears to stop growing
    y = jax.hessian(jf)(x)

Also note that replacing jax.hessian with jax.grad (or with jax.jacfwd, for functions with non-scalar outputs) in otherwise equivalent code does not cause a memory leak.
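For context, JAX builds jax.hessian(fun) as jax.jacfwd(jax.jacrev(fun)), i.e. forward-mode over reverse-mode differentiation. That composition is what distinguishes it from a single jax.grad or jax.jacfwd call. The check below is an illustrative sketch, not part of the original report. It confirms the equivalence on the reproducer's function and compares the result with the analytic Hessian of mean(x**2), which is (2 / n) times the identity for n = 32 elements:

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.square(x).mean()
jf = jax.jit(f)
x = jax.random.uniform(jax.random.PRNGKey(0), shape=(8, 4))

# jax.hessian composes forward-mode over reverse-mode differentiation.
h = jax.hessian(jf)(x)
h_manual = jax.jacfwd(jax.jacrev(jf))(x)

# The Hessian of a scalar function of an (8, 4) input has shape (8, 4, 8, 4).
print(h.shape)  # (8, 4, 8, 4)
print(jnp.allclose(h, h_manual))

# Analytic Hessian of mean(x**2): (2 / n) * identity, reshaped to (8, 4, 8, 4).
expected = (2.0 / x.size) * jnp.eye(x.size).reshape(8, 4, 8, 4)
print(jnp.allclose(h, expected))
```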

Because of this leak, my Python process eventually gets killed by the OS for running out of memory, so I never see any Python or JAX traceback.

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 10 (8 by maintainers)

Top GitHub Comments

1 reaction
mattjj commented, Aug 13, 2021

Thanks for ~the question~ raising this, and with a clear reproducer!

At first I thought this could be an artifact of asynchronous dispatch. (The jit dispatch path is different from the non-jit dispatch path, so that may play a role in why f and jf behave differently.)

But I ran this version, which forces synchronization on every iteration, and memory still seems to grow without bound (on the CPU backend):

while True:
    jax.hessian(jf)(x).block_until_ready()

I then tried running the script like this:

env JAX_LOG_COMPILES=1 python issue7621.py

and saw that we’re doing lots of recompiles here:

...
WARNING:absl:Compiling backward_pass (140099626713344) for args (ShapedArray(float32[8,4]), ShapedArray(float32[]), ShapedArray(float32[1]), ShapedArray(float32[32,8,4])).
WARNING:absl:Compiling backward_pass (140099626782592) for args (ShapedArray(float32[8,4]), ShapedArray(float32[]), ShapedArray(float32[1]), ShapedArray(float32[32,8,4])).
WARNING:absl:Compiling backward_pass (140099626820800) for args (ShapedArray(float32[8,4]), ShapedArray(float32[]), ShapedArray(float32[1]), ShapedArray(float32[32,8,4])).
WARNING:absl:Compiling backward_pass (140099626782144) for args (ShapedArray(float32[8,4]), ShapedArray(float32[]), ShapedArray(float32[1]), ShapedArray(float32[32,8,4])).
WARNING:absl:Compiling backward_pass (140099626862272) for args (ShapedArray(float32[8,4]), ShapedArray(float32[]), ShapedArray(float32[1]), ShapedArray(float32[32,8,4])).
WARNING:absl:Compiling backward_pass (140099626817664) for args (ShapedArray(float32[8,4]), ShapedArray(float32[]), ShapedArray(float32[1]), ShapedArray(float32[32,8,4])).
WARNING:absl:Compiling backward_pass (140099626755712) for args (ShapedArray(float32[8,4]), ShapedArray(float32[]), ShapedArray(float32[1]), ShapedArray(float32[32,8,4])).
WARNING:absl:Compiling backward_pass (140099626837568) for args (ShapedArray(float32[8,4]), ShapedArray(float32[]), ShapedArray(float32[1]), ShapedArray(float32[32,8,4])).
WARNING:absl:Compiling backward_pass (140099626919552) for args (ShapedArray(float32[8,4]), ShapedArray(float32[]), ShapedArray(float32[1]), ShapedArray(float32[32,8,4])).
WARNING:absl:Compiling backward_pass (140099626921344) for args (ShapedArray(float32[8,4]), ShapedArray(float32[]), ShapedArray(float32[1]), ShapedArray(float32[32,8,4])).
WARNING:absl:Compiling backward_pass (140099626884480) for args (ShapedArray(float32[8,4]), ShapedArray(float32[]), ShapedArray(float32[1]), ShapedArray(float32[32,8,4])).
WARNING:absl:Compiling backward_pass (140099626944512) for args (ShapedArray(float32[8,4]), ShapedArray(float32[]), ShapedArray(float32[1]), ShapedArray(float32[32,8,4])).
WARNING:absl:Compiling backward_pass (140099626883712) for args (ShapedArray(float32[8,4]), ShapedArray(float32[]), ShapedArray(float32[1]), ShapedArray(float32[32,8,4])).
WARNING:absl:Compiling backward_pass (140099626911936) for args (ShapedArray(float32[8,4]), ShapedArray(float32[]), ShapedArray(float32[1]), ShapedArray(float32[32,8,4])).
...
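The same compile logging can also be switched on from inside the script through JAX's config API. This sketch uses the jax_log_compiles option, which the JAX_LOG_COMPILES environment variable sets. The logger name and message format (the output above arrives via absl) have varied across JAX versions, so treat the exact log text as version-dependent:

```python
import jax
import jax.numpy as jnp

# Same effect as running with JAX_LOG_COMPILES=1: log a message for each compilation.
jax.config.update("jax_log_compiles", True)

f = lambda x: jnp.square(x).mean()
jf = jax.jit(f)
x = jax.random.uniform(jax.random.PRNGKey(0), shape=(8, 4))

# A few iterations are enough. If the compilation cache is being hit, compile
# messages should appear during the first iteration only; a steady stream of
# new "Compiling ..." lines on every iteration indicates cache misses.
for _ in range(3):
    jax.hessian(jf)(x).block_until_ready()
```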

Maybe that should’ve been obvious from your clue about grad/jacfwd behaving differently.

Somehow hessian is causing compilation cache misses here…
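The sketch below shows one general way such cache misses arise. It is offered as an illustration, not as the diagnosed cause of this issue. jax.jit caches traced and compiled code per function object, so wrapping a freshly created function on every call forces a retrace each time. That pattern resembles the distinct backward_pass ids in the log above. The names make_fn and trace_count are hypothetical. The counter is a Python side effect, so it increments only while JAX traces the function, not on cached calls:

```python
import jax
import jax.numpy as jnp

trace_count = 0

def make_fn():
    # Hypothetical factory: returns a brand-new function object on every call.
    def g(x):
        global trace_count
        trace_count += 1  # runs at trace time only, not when a cached trace is reused
        return jnp.square(x).mean()
    return g

x = jnp.ones((8, 4))

# One jitted object called repeatedly with the same shape and dtype: traced once.
stable = jax.jit(make_fn())
for _ in range(5):
    stable(x)
print(trace_count)  # 1

# Five distinct function objects (kept alive in a list): each misses the cache.
fresh = [jax.jit(make_fn()) for _ in range(5)]
for fn in fresh:
    fn(x)
print(trace_count)  # 6
```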

0 reactions
mattjj commented, Aug 15, 2021

I think it’s an instance of the same bug. Running your code against current HEAD, I only see ‘called’ printed once.
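The code that prints 'called' comes from elsewhere in the thread and is not shown here. The following is a plausible sketch of that kind of check, an assumption about its form rather than the thread's actual code. A Python print inside a jitted function runs only while JAX traces it. If repeated jax.hessian calls hit the cache, the message should therefore appear once, which matches the behavior reported at HEAD:

```python
import jax
import jax.numpy as jnp

@jax.jit
def jf(x):
    print("called")  # executes only when jf is traced, not on cached calls
    return jnp.square(x).mean()

x = jax.random.uniform(jax.random.PRNGKey(0), shape=(8, 4))

# Repeated Hessian evaluations; count how many times "called" is printed.
for _ in range(3):
    y = jax.hessian(jf)(x).block_until_ready()
```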
