Unexpectedly high grad-of-scan memory usage
Consider the following function that sums x * (y + z) over all y in ys and then averages over the resulting matrix of sums:
    import jax.lax
    import jax.numpy as jnp

    def f(x, ys):
        z = jnp.ones((3000, 3000))
        def scanned(carry, y):
            # Accumulate x * (y + z) into the carry; no per-step output.
            return carry + x * (y + z), None
        summed, _ = jax.lax.scan(scanned, jnp.zeros_like(z), ys)
        return summed.mean()
Because I use lax.scan (instead of, e.g., vmap or lax.map followed by a sum over the first axis), memory usage doesn't significantly scale with the number of ys. The following code uses ~203MB regardless of whether n = 5 or n = 10:
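    import resource
    n = 5  # or n = 10
    print(f(1.0, jnp.ones(n)))
    # ru_maxrss is the peak resident set size; it is reported in kilobytes on
    # Linux (bytes on macOS), so 1e-3 * ru_maxrss is roughly MB on Linux.
    print(f"{1e-3 * resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}MB")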
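For comparison, here is a minimal sketch of the scan version next to the vmap-then-sum alternative mentioned above. The names f_scan and f_vmap and the size parameter are hypothetical, added only so the sketch can run at small sizes. At size 3000 the vmap variant would build an (n, 3000, 3000) intermediate before reducing it, so its memory would be expected to grow with n, while the scan version only carries a single (3000, 3000) accumulator.

```python
import jax
import jax.numpy as jnp


def f_scan(x, ys, size=3000):
    # Same computation as f above; `size` is a hypothetical knob for testing.
    z = jnp.ones((size, size))

    def scanned(carry, y):
        return carry + x * (y + z), None

    summed, _ = jax.lax.scan(scanned, jnp.zeros_like(z), ys)
    return summed.mean()


def f_vmap(x, ys, size=3000):
    # vmap-then-sum variant: materializes an intermediate of shape
    # (len(ys), size, size) before summing over the first axis.
    z = jnp.ones((size, size))
    terms = jax.vmap(lambda y: x * (y + z))(ys)
    return terms.sum(axis=0).mean()


ys = jnp.arange(5.0)
print(f_scan(1.0, ys, size=4), f_vmap(1.0, ys, size=4))
```

Both functions compute the same value; only the shape of the intermediates differs.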
But the gradient uses 557MB for n = 5 and 908MB for n = 10:
import jax
print(jax.grad(f)(1.0, jnp.ones(n)))
print(f"{1e-3 * resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}MB")
The story is similar when these functions are jitted.
My best guess about what's going on here is that grad is storing every (y + z) in memory. Is this intended? And is there some way to tell grad to be more economical about what it stores, so that computing the gradient gets a memory reduction similar to the one lax.scan provides for the forward pass?
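One commonly used tool for this is rematerialization: wrapping the scan body in jax.checkpoint (also exposed as jax.remat) so that autodiff recomputes the body's intermediates, such as y + z, during the backward pass instead of saving one per iteration. The sketch below assumes that approach; f_remat and its size parameter are hypothetical names for illustration. It only checks that the gradient is unchanged. Whether peak memory actually drops here depends on the JAX version and on which of the body's inputs end up being saved, and has not been measured.

```python
import jax
import jax.numpy as jnp


def f_remat(x, ys, size=3000):
    z = jnp.ones((size, size))

    # jax.checkpoint asks autodiff not to save this function's internal
    # intermediates (e.g. y + z) for the backward pass; they are recomputed
    # from the function's inputs when the gradient is evaluated.
    @jax.checkpoint
    def scanned(carry, y):
        return carry + x * (y + z), None

    summed, _ = jax.lax.scan(scanned, jnp.zeros_like(z), ys)
    return summed.mean()


# Analytically, d/dx mean(sum_i x * (y_i + 1)) = sum_i (y_i + 1).
grad_fn = jax.grad(lambda x: f_remat(x, jnp.ones(5), size=4))
print(grad_fn(1.0))
```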
Very cool, I’ll keep my eyes peeled and keep updating the package. The work you all are doing here is really great.
By the way, we're working on some other improvements that should make this work well even without remat, by never instantiating the large ones((3000, 3000)) array. We'd still need remat in general, but in this case the memory savings can be had by avoiding the large constant.
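The improvement described above is internal to JAX, but the idea of avoiding the large constant can be approximated by hand in this particular example. The sketch below is a user-level workaround, not the maintainers' change: because z is all ones, it keeps z as the scalar 1.0, so y + z is a scalar and the value autodiff would need per iteration for the gradient with respect to x is a scalar rather than a (3000, 3000) array. The name f_scalar_z and the size parameter are hypothetical, and the memory effect has not been measured.

```python
import jax
import jax.numpy as jnp


def f_scalar_z(x, ys, size=3000):
    # z is all ones in the original, so represent it as a scalar; the
    # broadcast to (size, size) only happens when adding into the carry.
    z = 1.0

    def scanned(carry, y):
        return carry + x * (y + z), None

    summed, _ = jax.lax.scan(scanned, jnp.zeros((size, size)), ys)
    return summed.mean()


grad_fn = jax.grad(lambda x: f_scalar_z(x, jnp.ones(5), size=4))
print(grad_fn(1.0))
```

This only works because z is a constant that can be represented more compactly; for a general z, remat is still the relevant tool, as the comment above notes.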