Memory leak when repeatedly saving a variable during a host_callback call.

I have a use case where I want to use host_callback to save a Python object (specifically a scipy SuperLU object) and update it over and over again within a jitted function. The object should persist between jit calls. Unfortunately I am running into a memory leak. In the minimal example below, memory consumption grows linearly as the code runs. Even if I insert gc.collect() in the host_callback function, the SuperLU objects still seem to be kept around.

import numpy as np
import scipy.sparse
import scipy.sparse.linalg

import jax.experimental.host_callback as hcb
import jax
import jax.numpy as jnp

class SolveFnContainer:
    lu = None

def host_do_something(_):
    print('solving')

    n = 100000
    A = scipy.sparse.diags((-1*np.ones(n-1), -1*np.ones(n-1), 2*np.ones(n)), offsets=(-1, 1, 0), format='csc')
    SolveFnContainer.lu = scipy.sparse.linalg.spilu(A)
    return 0.0

def device_do_something():
    inputs = (0.0,)
    return hcb.call(host_do_something, inputs,
                    result_shape=jax.ShapeDtypeStruct((), np.float64))

def main():
    def body_fn(x):
        return device_do_something()

    def cond_fn(x):
        return x == 0.0

    return jax.lax.while_loop(cond_fn, body_fn, 0.0)

if __name__ == '__main__':
    jit_main = jax.jit(main)
    print(jit_main())

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 5 (5 by maintainers)

Top GitHub Comments

1 reaction
hawkinsp commented, Apr 1, 2022

I think this is a scipy bug. Here’s a reproduction without any JAX pieces:

from time import sleep

import numpy as np
import scipy.sparse
import scipy.sparse.linalg

import os
import psutil
import gc
import threading


class Container:
    A = None


def host_do_something(_):
    n = 100000
    A = scipy.sparse.diags((-1*np.ones(n-1), -1*np.ones(n-1), 2*np.ones(n)), offsets=(-1, 1, 0), format='csc')
    Container.A = scipy.sparse.linalg.spilu(A)
    #Container.A = A
    return 0.0


def thread_do_something(_):
    t = threading.Thread(target=host_do_something, args=(0.0,))
    t.start()
    t.join()

if __name__ == '__main__':
    while True:
        thread_do_something(0.0)

        print(psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2)
        sleep(0.2)
        gc.collect()

It appears the leak occurs because JAX is running your callback on another thread. For some reason this triggers a scipy leak. I don’t know what’s happening without debugging scipy, but at this point it’s not really a JAX problem.
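One way to check whether threading is the trigger is to run the same step directly and on a fresh thread, and compare memory growth. The following is a minimal sketch, not part of the original comment: `run_on_fresh_thread` and `memory_growth` are hypothetical helper names, and `measure` is assumed to be any zero-argument callable that returns current memory usage (for example, a wrapper around the psutil expression in the reproduction above).

```python
import threading


def run_on_fresh_thread(fn, *args):
    """Run fn(*args) on a newly created thread, wait for it, return its result."""
    result = []
    t = threading.Thread(target=lambda: result.append(fn(*args)))
    t.start()
    t.join()
    # If fn raised inside the thread, result is empty and this raises IndexError.
    return result[0]


def memory_growth(step, measure, runs=20):
    """Return measure() after calling step() `runs` times, minus measure() before."""
    before = measure()
    for _ in range(runs):
        step()
    return measure() - before
```

With the reproduction's `host_do_something`, one could compare `memory_growth(lambda: host_do_something(0.0), measure)` against `memory_growth(lambda: run_on_fresh_thread(host_do_something, 0.0), measure)`. If only the second number keeps growing as `runs` increases, that would be consistent with the threading explanation.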

As to why this happens under JAX on GPU but not CPU: JAX uses different callback implementations on CPU and GPU, and only on GPU is it calling back on another thread. (As it happens, @sharadmv is working on a GPU callback implementation that uses only one thread, but that is ultimately papering over the problem I think.)
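The single-thread idea can be imitated at the Python level by funnelling the host work through one long-lived worker thread, so the factorization always runs on the same thread regardless of which thread invokes the callback. This is only a sketch under that assumption: `run_on_worker` and `worker_thread_id` are hypothetical names, and nothing in this thread establishes that doing so avoids the scipy leak. Like the approach mentioned above, it would paper over the problem rather than fix it.

```python
import threading
from concurrent.futures import ThreadPoolExecutor

# One long-lived worker thread; every submitted job runs on it.
_worker = ThreadPoolExecutor(max_workers=1)


def run_on_worker(fn, *args):
    """Run fn(*args) on the shared worker thread and block until it finishes."""
    return _worker.submit(fn, *args).result()


def worker_thread_id():
    """Return the identifier of the thread that executes submitted jobs."""
    return run_on_worker(threading.get_ident)
```

Inside the callback, the body of `host_do_something` could then be moved into a helper and invoked through `run_on_worker`, which makes it easy to test whether the leak still appears when every factorization shares one thread.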

Hope that helps!

0 reactions
denizokt commented, Apr 1, 2022

Thanks so much! I will look into this and redirect to scipy if needed.
