Add Grad-CAM implementation
Feature
Implementation of Grad-CAM, probably as a vision callback.
Motivation
Grad-CAM is a widely used localization method that uses the gradient information flowing into the last convolutional layer of a CNN to visualize the model's prediction by highlighting the "important" pixels in the image.
The technique does not require any modifications to the existing model architecture, so it can be applied to any CNN-based architecture, including those for image captioning and visual question answering.
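To make that concrete, here is a minimal sketch of the core computation for a standard image classifier (the hook-based approach and the function name are just illustrative, not a proposed API):

```python
import torch
import torch.nn.functional as F

def grad_cam(model, image, target_class, conv_layer):
    """Minimal Grad-CAM sketch; conv_layer is the last conv layer of model."""
    activations, gradients = [], []

    # Capture the feature maps and their gradients with hooks.
    fwd = conv_layer.register_forward_hook(lambda m, i, o: activations.append(o))
    bwd = conv_layer.register_full_backward_hook(lambda m, gi, go: gradients.append(go[0]))

    model.zero_grad()
    score = model(image.unsqueeze(0))[0, target_class]  # forward pass, class score
    score.backward()                                     # backward pass from that score
    fwd.remove()
    bwd.remove()

    acts, grads = activations[0], gradients[0]           # each (1, C, h, w)
    weights = grads.mean(dim=(2, 3), keepdim=True)       # global-average-pool the gradients
    cam = F.relu((weights * acts).sum(dim=1))            # channel-weighted sum, then ReLU
    return cam / (cam.max() + 1e-8)                      # (1, h, w), normalized to [0, 1]
```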
Pitch
Right now, for research we're doing in the lab I'm working in, I've been using a modified version of this PyTorch implementation of Grad-CAM, which only works with batch_size = 1 (but it looks like there are many other PyTorch implementations, with many stars, on GitHub that we could work off of).
For the above research, I've added an if statement to the test_step function in our LightningModule so that if we want the cams to be saved during inference, it calls a separate util function localize that does the forward and backward pass to create the feature maps. I'm not sure this is ideal, though, because we later do another forward and backward pass on the same image to get the prediction, so there is duplicated work.
I was thinking that it would be nice to have some sort of callback that can just generate (and save?) the cams for you, without having to mess with the training pipeline. I guess we would have to figure out where the cams would be saved, and whether just the heatmap would be saved, or the heatmap overlaid on the original image (which is probably the most helpful?).
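To make the idea concrete, here is a rough sketch of what such a callback might look like; the class name, constructor arguments, the assumption that the batch is an (images, targets) pair, and the choice to save raw tensors to disk are all placeholders, and grad_cam refers to the sketch in the Motivation section above:

```python
import os

import torch
import pytorch_lightning as pl

class GradCAMCallback(pl.Callback):
    """Hypothetical callback that generates and saves cams during testing."""

    def __init__(self, layer_name, save_dir="cams"):
        self.layer_name = layer_name
        self.save_dir = save_dir

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        images, targets = batch
        layer = dict(pl_module.named_modules())[self.layer_name]
        with torch.enable_grad():  # the test loop normally runs without gradients
            cams = [grad_cam(pl_module, img, tgt.item(), layer) for img, tgt in zip(images, targets)]
        os.makedirs(self.save_dir, exist_ok=True)
        torch.save(torch.stack(cams), os.path.join(self.save_dir, f"batch_{batch_idx}.pt"))
```

One open design question is whether the callback should save the raw heatmaps like this or the heatmap overlaid on the original image, as mentioned above.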
Alternatives
Other localization methods include Integrated Gradients, WILDCAT, and Grad-CAM++, but Grad-CAM seems to be the most widely-used.
Additional context
I'm not sure if this is helpful, but to clarify further how I'm currently doing this in Lightning: I've created inference_step and inference_epoch_end functions that both my valid functions and my test functions in the LightningModule call (that way, we can make sure that both valid and test are doing inference in the same way). Only my test_step has a separate if statement that's called only if the user wants to also generate cams.
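Roughly, the structure looks like this (class and hyperparameter names are just illustrative, and localize is the util function mentioned in the Pitch):

```python
import pytorch_lightning as pl
import torch.nn.functional as F

class LitClassifier(pl.LightningModule):
    def inference_step(self, batch, batch_idx):
        # Shared logic so that validation and test run inference identically.
        x, y = batch
        logits = self(x)
        return {"loss": F.cross_entropy(logits, y), "preds": logits.argmax(dim=1)}

    def inference_epoch_end(self, outputs):
        ...  # shared metric aggregation / logging

    def validation_step(self, batch, batch_idx):
        return self.inference_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        self.inference_epoch_end(outputs)

    def test_step(self, batch, batch_idx):
        out = self.inference_step(batch, batch_idx)
        if self.hparams.save_cams:   # only test_step optionally generates cams
            localize(self, batch)    # extra forward + backward pass to build the cams
        return out

    def test_epoch_end(self, outputs):
        self.inference_epoch_end(outputs)
```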
Anyway, I'd love to help out on this in any way I can! I've never written a callback before, though, so would need some guidance on how to approach that.

How about integrating captum for the same? They have a wide range of interpretability algorithms implemented with rigorous testing. We could simply write a Callback for captum and add "captum" as an explicit dependency, throwing an exception at runtime if it isn't installed. Captum also provides utility functions for visualization, so it'll save a lot of effort if we just use the library.
cc: @ASaporta
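For what it's worth, a minimal sketch of what the captum path could look like (the resnet18 model, its layer4 attribute, and the input shape are just placeholders):

```python
import torch
from torchvision.models import resnet18
from captum.attr import LayerGradCam, LayerAttribution
from captum.attr import visualization as viz

model = resnet18(pretrained=True).eval()      # any CNN; layer4 is its last conv block
grad_cam = LayerGradCam(model, model.layer4)

image = torch.rand(1, 3, 224, 224)            # dummy input
pred_class = model(image).argmax(dim=1)
attr = grad_cam.attribute(image, target=pred_class)           # (1, 1, h, w) cam
attr = LayerAttribution.interpolate(attr, image.shape[-2:])   # upsample to the input size

# Captum's built-in visualization: heatmap blended over the original image.
viz.visualize_image_attr(
    attr[0].permute(1, 2, 0).detach().numpy(),
    original_image=image[0].permute(1, 2, 0).numpy(),
    method="blended_heat_map",
    sign="positive",
)
```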
@ASaporta in case this goes through, after a quick look at the code I see at least a couple of potential places where some utilities from kornia could be used here: resize and normalize_min_max.
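If that direction is taken, the usage would presumably be something along these lines (shapes and the blending step are only illustrative):

```python
import torch
from kornia.geometry.transform import resize
from kornia.enhance import normalize_min_max

cams = torch.rand(4, 1, 7, 7)        # raw (B, 1, h, w) activation maps
images = torch.rand(4, 3, 224, 224)  # the (B, 3, H, W) input batch

cams = resize(cams, (224, 224))            # upsample the cams to the input resolution
cams = normalize_min_max(cams, 0.0, 1.0)   # rescale each cam to [0, 1]
overlay = 0.5 * images + 0.5 * cams        # naive blend for visualization
```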