
Massive VRAM Usage for Feature Ablation and Shapley Value Sampling

See original GitHub issue

❓ Questions and Help

Hello, I have just set up Captum to analyze a resnet50 model (taken directly from torchvision). Captum works perfectly as expected for most of the attribution methods. However, when I try to use feature ablation and Shapley value sampling, I get out-of-memory errors despite using perturbations_per_eval=1 (the minimum value). I also tried to run the same code with data parallel across 3 GPUs instead, but the program still runs out of memory.

For reference, I am using RTX 8000 GPUs with 48 GB of VRAM each. Using feature ablation, for example, the program swallowed almost the entire 48*3 = 144 GB of VRAM (roughly, minus a few other programs running in the background)! Is this expected? The data parallel load balancing was not great either: one GPU had only about 30 GB of its 48 GB in use when another hit OOM. I have a total of 4x RTX 8000 to try, but one is currently in use. Also for reference, the resnet50 model should only take up maybe 5 GB of VRAM at most during training. The tensor size is standard ImageNet (Nx3x224x224).

Here is a code snippet:

    print("Getting feature ablation...")
    fig = captum_viz(prep_img, model, None, default_cmap, FeatureAblation, target=1,
                     perturbations_per_eval=1, show_progress=True)
    fig.savefig(file_name_to_export + '_feature_ablation.png')
def captum_viz(prep_img, model, background, default_cmap, method, target=1, **kwargs):
    single_img = prep_img.clone().cuda()
    class_instance = method(model)
    # error is on the instance.attribute line, errors with or without background
    if background is not None:
        background = background.clone()
        attributions_ig = class_instance.attribute(single_img, baselines=background, target=target, **kwargs)
    else:
        attributions_ig = class_instance.attribute(single_img, target=target, **kwargs)

...
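As a point of scale (separate from the leak discussed in the comments below): with no feature_mask, Captum's FeatureAblation ablates every scalar of the input independently, so a single 3x224x224 image means 3*224*224 = 150,528 forward passes at perturbations_per_eval=1. Below is a rough sketch of grouping pixels into superpixels with the feature_mask argument to cut that down; the 16x16 grid size and variable names are illustrative and not from the original code:

    import torch
    import torchvision
    from captum.attr import FeatureAblation

    model = torchvision.models.resnet50().cuda().eval()
    inp = torch.randn(1, 3, 224, 224, device="cuda")

    # 14x14 grid of 16x16-pixel superpixels -> ~196 ablation steps instead of 150,528
    cell = 16
    ids = torch.arange((224 // cell) ** 2).reshape(224 // cell, 224 // cell)
    mask = ids.repeat_interleave(cell, dim=0).repeat_interleave(cell, dim=1)  # 224x224 group ids
    mask = mask.view(1, 1, 224, 224).expand(1, 3, 224, 224).cuda()

    attr = FeatureAblation(model).attribute(
        inp, target=1, feature_mask=mask, perturbations_per_eval=1, show_progress=True
    )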

Issue Analytics

  • State: open
  • Created 2 years ago
  • Comments: 8 (3 by maintainers)

Top GitHub Comments

2 reactions
ndalton12 commented, Apr 6, 2021

Hello, yes it certainly looks like a memory leak.

I was not able to reproduce the error in Colab originally, but I have now narrowed down the problem. For whatever reason, using the model (specifically, the nn.Module) produced by wrapping it in a pytorch-lightning module leads to this leak. The fix was to create a fresh resnet50 from torchvision and copy in the state dict from the pytorch-lightning-wrapped version:

    import torch
    import torchvision

    # Rebuild the backbone outside the Lightning wrapper and copy the trained weights over.
    model2 = torchvision.models.resnet50()
    num_ftrs = model2.fc.in_features
    model2.fc = torch.nn.Linear(num_ftrs, 2)   # same 2-class head as the trained model
    model2.classifier = model2.fc
    model2.load_state_dict(model.model.state_dict())
    model2.cuda()

Now, when I use model2 instead of model (the pytorch-lightning module, which subclasses nn.Module among other things), there is no memory leak. Weirdly enough, even when I use model.model (the nn.Module held inside the pytorch-lightning module), the memory leak is still present. Parameter freezing (setting param.requires_grad=False) doesn't seem to be the difference maker, and it's the only difference between model2 and model.model that I can think of. The memory leak seems to be related to batch norm, since that's what shows up in the stack trace. I am not sure whether this is a pytorch-lightning issue or something else.
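A quick diagnostic sketch for the comparison described above, checking the two suspected differences (frozen parameters and batch-norm training state) between the Lightning-held backbone and the fresh copy; summarize is a hypothetical helper, and model / model2 refer to the snippet above:

    import torch

    def summarize(m, name):
        # Report frozen parameters and BatchNorm layers still in training mode.
        params = list(m.parameters())
        frozen = sum(not p.requires_grad for p in params)
        bn = [mod for mod in m.modules() if isinstance(mod, torch.nn.BatchNorm2d)]
        training = sum(mod.training for mod in bn)
        print(f"{name}: {frozen}/{len(params)} params frozen, "
              f"{training}/{len(bn)} BatchNorm layers in training mode")

    summarize(model.model, "model.model")  # backbone inside the Lightning module
    summarize(model2, "model2")            # fresh torchvision copy from above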

Also, on a side note, the show_progress feature does not seem to work. This fix does allow me to run the code without errors, so feel free to close the issue if you see fit, but it is more of a workaround than a fundamental fix.

1 reaction
ndalton12 commented, Apr 11, 2021

Hi @vivekmig, I have managed to create a somewhat reproducible example: https://colab.research.google.com/drive/165Zj7wdiFaawbmzGROYqf3__UqY_-DgP?usp=sharing. The issue does not seem to be related to PyTorch Lightning, since I can recreate it without it. Make sure to switch the runtime to GPU in the settings. Basically, if you stop FeatureAblation part way through, you can see that the GPU's VRAM usage increases (doubles!) without being freed. In my Colab example, around 630 MB of VRAM is held before calling .attribute(...); after stopping .attribute(...) part way through, around 1400 MB is held. I believe this issue is also present in ShapleyValueSampling and perhaps some other methods as well.

It could be that the VRAM is freed only when the method finishes cleanly; I am testing that on Colab now. Regardless, I don't think this behavior is expected, and it seems possible to consume arbitrary amounts of VRAM by starting and stopping one of these methods. This is likely the source of both problems I was experiencing.
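One way to script the "stopped part way through" measurement outside a notebook, as a rough sketch; the forward-call counter and the 50-call cutoff are illustrative and not part of the original Colab:

    import torch
    import torchvision
    from captum.attr import FeatureAblation

    model = torchvision.models.resnet50().cuda().eval()
    inp = torch.randn(1, 3, 224, 224, device="cuda")
    calls = {"n": 0}

    def forward_func(x):
        # Abort after a fixed number of evals to mimic interrupting attribute() mid-run.
        calls["n"] += 1
        if calls["n"] > 50:
            raise RuntimeError("stopping attribution part way through")
        return model(x)

    print("allocated before:", torch.cuda.memory_allocated() // 2**20, "MiB")
    try:
        FeatureAblation(forward_func).attribute(inp, target=1, perturbations_per_eval=1)
    except RuntimeError:
        pass
    print("allocated after: ", torch.cuda.memory_allocated() // 2**20, "MiB")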

