
Unexpectedly high grad-of-scan memory usage

See original GitHub issue

Consider the following function that sums x * (y + z) over all y in ys and then averages over the resulting matrix of sums:

import jax.lax
import jax.numpy as jnp

def f(x, ys):
    # Large constant closed over by the scan body.
    z = jnp.ones((3000, 3000))

    def scanned(carry, y):
        # Accumulate x * (y + z) into the running sum.
        return carry + x * (y + z), None

    summed, _ = jax.lax.scan(scanned, jnp.zeros_like(z), ys)
    return summed.mean()

Because I use lax.scan (instead of, e.g., vmap or lax.map followed by a sum over the first axis), memory usage doesn’t significantly scale with the number of ys. The following code uses ~203MB regardless of whether n = 5 or n = 10:

import resource

n = 5  # or n = 10; peak memory is roughly the same either way

print(f(1.0, jnp.ones(n)))
# ru_maxrss is reported in kilobytes on Linux, so this prints megabytes.
print(f"{1e-3 * resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}MB")

But the gradient uses 557MB for n = 5 and 908MB for n = 10:

import jax

print(jax.grad(f)(1.0, jnp.ones(n)))
print(f"{1e-3 * resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}MB")

The story is similar when these functions are jitted.
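
For concreteness, "jitted" here means wrapping the same calls in jax.jit, e.g. (reusing f and n from the snippets above):

f_jit = jax.jit(f)
grad_f_jit = jax.jit(jax.grad(f))

print(f_jit(1.0, jnp.ones(n)))
print(grad_f_jit(1.0, jnp.ones(n)))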

My best guess about what’s going on here is that grad is storing every (y + z) in memory. Is this intended? And is there some way to tell grad to be more economical about what it stores, so that computing the gradient gets a memory reduction similar to the one lax.scan provides on the forward pass?
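
One way to get that kind of memory reduction, sketched below for the f defined above, is to rematerialize the scan body with jax.checkpoint (also exposed as jax.remat): the backward pass then recomputes per-step intermediates such as (y + z) instead of storing one per step. The name f_remat is just an illustrative copy of f.

import jax
import jax.lax
import jax.numpy as jnp

def f_remat(x, ys):
    z = jnp.ones((3000, 3000))

    def scanned(carry, y):
        return carry + x * (y + z), None

    # Checkpointing the body tells autodiff not to save the per-step
    # intermediates; they are recomputed during the backward pass.
    summed, _ = jax.lax.scan(jax.checkpoint(scanned), jnp.zeros_like(z), ys)
    return summed.mean()

print(jax.grad(f_remat)(1.0, jnp.ones(5)))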

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Reactions: 1
  • Comments: 5 (1 by maintainers)

Top GitHub Comments

1 reaction
jeffgortmaker commented, May 22, 2020

Very cool, I’ll keep my eyes peeled and keep updating the package. The work you all are doing here is really great.

1 reaction
mattjj commented, May 22, 2020

By the way, we’re working on some other improvements that should make this work well even without remat by never instantiating the large ones((3000, 3000)) array. We’d still need remat in general, but in this case the memory savings can be had by avoiding the large constant.
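
For this particular example, that point can also be illustrated by hand: because z is an all-ones constant, x * (y + z) is just x * (y + 1) broadcast across the matrix, so the scan can run entirely on scalars and the (3000, 3000) array never needs to exist. The sketch below (with the hypothetical name f_small, using the same imports as the snippets above) is specific to this example and not a general substitute for remat.

def f_small(x, ys):
    def scanned(carry, y):
        # x * (y + z) with z = ones((3000, 3000)) is x * (y + 1) broadcast,
        # so the running sum can be kept as a scalar.
        return carry + x * (y + 1.0), None

    summed, _ = jax.lax.scan(scanned, jnp.zeros(()), ys)
    # Every entry of the original summed matrix equals this scalar,
    # so its mean equals the scalar itself.
    return summed

print(jax.grad(f_small)(1.0, jnp.ones(5)))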
