Memory leak when repeatedly saving a variable during a host_callback call.
I have a use case where I want to use host_callback to save a Python object (specifically a scipy SuperLU object) and update it over and over again within a jitted function. The object should persist between jit calls. Unfortunately, I am running into a memory leak. Below is a minimal example in which memory consumption grows linearly as the code runs. Even if I insert gc.collect() in the host_callback function, the SuperLU objects somehow seem to be kept around.
import numpy as np
import scipy.sparse
import scipy.sparse.linalg
import jax.experimental.host_callback as hcb
import jax
import jax.numpy as jnp

class SolveFnContainer:
    lu = None

def host_do_something(_):
    print('solving')
    n = 100000
    A = scipy.sparse.diags((-1*np.ones(n-1), -1*np.ones(n-1), 2*np.ones(n)), offsets=(-1, 1, 0), format='csc')
    SolveFnContainer.lu = scipy.sparse.linalg.spilu(A)
    return 0.0

def device_do_something():
    inputs = (0.0,)
    return hcb.call(host_do_something, inputs,
                    result_shape=jax.ShapeDtypeStruct((), np.float64))

def main():
    def body_fn(x):
        return device_do_something()
    def cond_fn(x):
        return x == 0.0
    return jax.lax.while_loop(cond_fn, body_fn, 0.0)

if __name__ == '__main__':
    jit_main = jax.jit(main)
    print(jit_main())
Issue Analytics
- Created 2 years ago
- Comments: 5 (5 by maintainers)
I think this is a scipy bug. Here’s a reproduction without any JAX pieces:
It appears the leak occurs because JAX is running your callback on another thread. For some reason this triggers a scipy leak. I don’t know what’s happening without debugging scipy, but at this point it’s not really a JAX problem.
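A JAX-free reproduction along these lines can be sketched as follows. This is a minimal sketch, assuming the trigger is simply calling `spilu` from a thread other than the main thread. `repro`, `peak_rss`, and the `--main-thread` flag are hypothetical names introduced here for illustration. Peak-RSS tracking relies on Python's Unix-only `resource` module (`ru_maxrss` is reported in KiB on Linux and in bytes on macOS).

```python
# Hypothetical JAX-free reproduction sketch (assumption: the leak is triggered
# by running spilu on a non-main thread). Uses the Unix-only `resource` module.
import resource
import sys
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import scipy.sparse
import scipy.sparse.linalg


class SolveFnContainer:
    lu = None


def host_do_something(n=100000):
    # Same tridiagonal CSC matrix as in the original report.
    A = scipy.sparse.diags(
        (-1 * np.ones(n - 1), -1 * np.ones(n - 1), 2 * np.ones(n)),
        offsets=(-1, 1, 0),
        format="csc",
    )
    # Overwrite the stored SuperLU object; the previous one should become garbage.
    SolveFnContainer.lu = scipy.sparse.linalg.spilu(A)


def peak_rss():
    # Process-wide peak resident set size (never decreases within a process).
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss


def repro(iterations=50, n=100000, threaded=True):
    """Run the callback repeatedly and record peak RSS after each call.

    threaded=True runs every call on one long-lived worker thread (roughly
    mimicking a callback thread); threaded=False runs on the calling thread.
    """
    samples = []
    with ThreadPoolExecutor(max_workers=1) as pool:
        for _ in range(iterations):
            if threaded:
                pool.submit(host_do_something, n).result()
            else:
                host_do_something(n)
            samples.append(peak_rss())
    return samples


if __name__ == "__main__":
    threaded = "--main-thread" not in sys.argv
    for i, rss in enumerate(repro(threaded=threaded)):
        print(i, rss)
```

Because `ru_maxrss` is a process-wide peak, run the threaded and `--main-thread` variants in separate processes. If the leak really is thread-related, the threaded run's numbers should keep climbing while the main-thread run levels off.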
As to why this happens under JAX on GPU but not CPU: JAX uses different callback implementations on CPU and GPU, and only on GPU is it calling back on another thread. (As it happens, @sharadmv is working on a GPU callback implementation that uses only one thread, but that is ultimately papering over the problem I think.)
Hope that helps!
Thanks so much! I will look into this and redirect to scipy if needed.