Massive VRAM Usage for Feature Ablation and Shapley Value Sampling
❓ Questions and Help
Hello, I have just set up Captum to analyze a resnet50 model (taken directly from torchvision). Captum works as expected for most of the attribution methods. However, when I try to use feature ablation and Shapley value sampling, I get out-of-memory errors despite using perturbations_per_eval=1 (the minimum value). I also tried running the same code with data parallel across 3 GPUs, but the program still runs out of memory.
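The data parallel run just wraps the model before handing it to Captum, roughly like this (the device ids below are illustrative):

    import torch

    # Spread the forward passes for the perturbed inputs across 3 of the GPUs.
    model = torch.nn.DataParallel(model, device_ids=[0, 1, 2]).cuda().eval()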
For reference, I am using RTX 8000 GPUs with 48 GB of VRAM each. With feature ablation, for example, the program consumed nearly the entire 48*3 = 144 GB of VRAM (roughly, minus a few other programs running in the background). Is this expected? The data parallel load balancing was also not great: one GPU had only about 30 GB of its 48 GB in use when another ran out of memory. I have 4x RTX 8000 in total, but one is currently in use. Also for reference, the resnet50 model should only take up around 5 GB of VRAM at most during training. The input tensors are standard ImageNet size (Nx3x224x224).
Here is a code snippet:
print("Getting feature ablation...")
fig = captum_viz(prep_img, model, None, default_cmap, FeatureAblation, target=1,
perturbations_per_eval=1, show_progress=True)
fig.savefig(file_name_to_export + '_feature_ablation.png')
def captum_viz(prep_img, model, background, default_cmap, method, target=1, **kwargs):
single_img = prep_img.clone().cuda()
class_instance = method(model)
# error is on the instance.attribute line, errors with or without background
if background is not None:
background = background.clone()
attributions_ig = class_instance.attribute(single_img, baselines=background, target=target, **kwargs)
else:
attributions_ig = class_instance.attribute(single_img, target=target, **kwargs)
...
Hello, yes it certainly looks like a memory leak.
I was not able to reproduce the error in colab originally, but I have now narrowed down the problem. For whatever reason, using the nn.Module produced by wrapping the model in a pytorch-lightning wrapper leads to this leak. The fix was to create a fresh resnet50 from torchvision and copy in the state dict from the pytorch-lightning wrapped version.
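Roughly, that looks like this (assuming the Lightning module keeps the underlying torchvision network as model.model, as described below):

    import torchvision

    # Fresh resnet50 from torchvision; copy the trained weights over from the
    # network held inside the pytorch-lightning module (model.model).
    model2 = torchvision.models.resnet50()
    model2.load_state_dict(model.model.state_dict())
    model2 = model2.cuda().eval()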
Now, when I use model2 instead of model (the pytorch-lightning module, which subclasses nn.Module among other things), there is no memory leak. Weirdly enough, even when I use model.model (the nn.Module held inside the pytorch-lightning module), the memory leak is still present. Parameter freezing (setting param.requires_grad=False) is the only difference between model2 and model.model that I can think of, but it does not seem to be the deciding factor. The memory leak seems to be related to batch norm, since that is what shows up in the stack trace. I am not sure whether this is a pytorch-lightning issue or something else.
On a side note, the show_progress feature does not seem to work. Still, this fix lets me run the code without errors, so feel free to close the issue if you see fit, although it is more of a workaround than a fundamental fix.
Hi @vivekmig, I have managed to create a somewhat reproducible example: https://colab.research.google.com/drive/165Zj7wdiFaawbmzGROYqf3__UqY_-DgP?usp=sharing. The issue does not seem to be related to PyTorch Lightning, since I can recreate it without it. Make sure to switch the runtime to GPU in the settings. Basically, if you stop FeatureAblation part way through, you can see that the GPU's VRAM usage increases (doubles!) without being freed. In my colab example, around 630 MB of VRAM is held before calling .attribute(...); after stopping .attribute(...) part way through, around 1400 MB is held. I believe this issue is also present in ShapleyValueSampling, and perhaps in some other methods as well.
It could be that the VRAM is only freed when the method finishes cleanly; I am testing that on colab now. Regardless, I do not think this behavior is expected, and it seems possible to consume arbitrary VRAM by calling and stopping one of these methods. This is likely the source of both problems I was experiencing.
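For what it is worth, the pattern can be checked with a standalone sketch along these lines (not the exact notebook code; here the held memory is read via torch.cuda.memory_allocated() and torch.cuda.memory_reserved()):

    import torch
    from torchvision.models import resnet50
    from captum.attr import FeatureAblation

    model = resnet50().cuda().eval()
    inp = torch.rand(1, 3, 224, 224).cuda()

    print("before:", torch.cuda.memory_allocated() // 2**20, "MiB allocated,",
          torch.cuda.memory_reserved() // 2**20, "MiB reserved")

    fa = FeatureAblation(model)
    try:
        # Interrupting this call part way through (e.g. stopping the cell)
        # is what leaves the extra VRAM held.
        attributions = fa.attribute(inp, target=1, perturbations_per_eval=1)
    except KeyboardInterrupt:
        pass

    print("after:", torch.cuda.memory_allocated() // 2**20, "MiB allocated,",
          torch.cuda.memory_reserved() // 2**20, "MiB reserved")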