Grad-CAM with Pythia
Hi again,
I integrated Grad-CAM into the VQA model Pythia. Pythia is a pipeline of multiple networks (multiple inputs, one output). The route through the network that interests me is the following: first, the input image is fed through a ResNet-152. The resulting features, together with features from other networks, are then fed through the Pythia network, which produces the prediction. Currently I wrap GradCAM around the resnet152 model and the Pythia model separately, since unfortunately you cannot wrap both together 😕 Then I call backward only on the Pythia model and call generate at layer ‘7’ of the resnet152 model.
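If the ResNet-152 features reach Pythia without being detached, one backward pass started from Pythia's output can propagate into the ResNet-152 layers. The following is a minimal PyTorch sketch of that idea, not the actual Pythia code: a tiny CNN (hypothetical `backbone`) stands in for ResNet-152, a small fusion module (hypothetical `Head`) stands in for Pythia, and a forward hook on one backbone layer records its activation and attaches a tensor hook that records its gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for ResNet-152: a tiny CNN producing a feature map.
backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),  # target layer (child "2")
    nn.ReLU(),
)


class Head(nn.Module):
    """Hypothetical stand-in for Pythia: fuses image features with another input."""

    def __init__(self, img_dim=16, other_dim=10, num_answers=5):
        super().__init__()
        self.fc = nn.Linear(img_dim + other_dim, num_answers)

    def forward(self, img_feats, other_feats):
        pooled = img_feats.mean(dim=(2, 3))  # global average pool -> (N, img_dim)
        return self.fc(torch.cat([pooled, other_feats], dim=1))


head = Head()
store = {}


def save_activation(module, inputs, output):
    # Record the activation and attach a hook that records its gradient later.
    store["act"] = output.detach()
    output.register_hook(lambda grad: store.__setitem__("grad", grad.detach()))


backbone[2].register_forward_hook(save_activation)

image = torch.randn(1, 3, 14, 14)
other = torch.randn(1, 10)  # hypothetical features from another network

features = backbone(image)  # not detached, so the graph spans both models
logits = head(features, other)
top = logits.argmax(dim=1, keepdim=True)

# A single backward pass, started from the head's top-scoring output only.
one_hot = torch.zeros_like(logits).scatter_(1, top, 1.0)
backbone.zero_grad()
head.zero_grad()
logits.backward(gradient=one_hot)

print(store["act"].shape, store["grad"].shape)
```

Attaching a tensor hook inside the forward hook (rather than a module backward hook) records the gradient with respect to exactly the tensor the target layer produced, which is what a Grad-CAM map is computed from.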
Code looks similar to this:
```python
resnet152_model = GradCAM(model=resnet152_model)
pythia_model = GradCAM(model=pythia_model)
features = resnet152_model.forward(input_image)
# Add features from all networks to feature list
# ...
probs, ids = pythia_model.forward(feature_list)
# Backpropagate only from Pythia's top-1 prediction
pythia_model.backward(ids=ids[:, [0]])
# Assumption: ResNet wrapped as nn.Sequential of its children, so child "7" is layer4
regions = resnet152_model.generate(target_layer="7")
```
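The `generate` call combines the activations and gradients captured at the target layer. In the standard Grad-CAM formulation, each channel weight is the spatial average of that channel's gradient, and the map is the ReLU of the weighted sum of activation channels. A dependency-free sketch of that step for a single sample follows; the function name `grad_cam_map` is hypothetical, and Grad-CAM implementations commonly also upsample the map to the input size and normalize it, which is omitted here.

```python
def grad_cam_map(activations, gradients):
    """Combine one layer's activations and gradients into a Grad-CAM map.

    activations, gradients: nested lists of shape [K][H][W] for one sample.
    Returns an [H][W] map: ReLU(sum_k alpha_k * A_k), alpha_k = mean(grad_k).
    """
    num_channels = len(activations)
    height, width = len(activations[0]), len(activations[0][0])

    # alpha_k: global-average-pooled gradient of channel k
    alphas = [
        sum(sum(row) for row in gradients[k]) / (height * width)
        for k in range(num_channels)
    ]

    # Weighted sum of activation channels
    cam = [[0.0] * width for _ in range(height)]
    for k in range(num_channels):
        for i in range(height):
            for j in range(width):
                cam[i][j] += alphas[k] * activations[k][i][j]

    # ReLU keeps only regions with a positive influence on the target score
    return [[max(0.0, v) for v in row] for row in cam]
```

For two channels whose averaged gradients are +1 and -1, the map is the first activation channel minus the second, clipped at zero.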
Is this approach correct? I am getting seemingly correct results for most images. Should I call resnet152_model.backward too, and if so, how?
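Whether a single `pythia_model.backward` call can reach the ResNet-152 layers depends on whether the autograd graph connects the two models. The following minimal PyTorch sketch uses hypothetical `nn.Linear` stand-ins rather than the real networks and checks whether a backward pass from the downstream score populates the upstream weights' gradients when the features are kept, detached, or computed under `torch.no_grad()`. The post does not say whether the actual Pythia pipeline detaches its features anywhere.

```python
import torch
import torch.nn as nn


def upstream_grad_reaches(mode: str) -> bool:
    """mode: "keep", "detach", or "no_grad" (how features cross the boundary)."""
    torch.manual_seed(0)
    backbone = nn.Linear(4, 3)  # hypothetical stand-in for ResNet-152
    head = nn.Linear(3, 2)      # hypothetical stand-in for Pythia
    x = torch.randn(1, 4)

    if mode == "no_grad":
        with torch.no_grad():   # no graph is recorded for the backbone
            feats = backbone(x)
    else:
        feats = backbone(x)
        if mode == "detach":
            feats = feats.detach()  # cuts the graph between the two models

    logits = head(feats)
    # Backward from the downstream model's top score only
    logits[0, int(logits.argmax())].backward()
    return backbone.weight.grad is not None
```

In this toy setup only the "keep" case gives the upstream model gradients from the one backward call; in the other two cases hooks on upstream layers would never receive a gradient.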
Issue Analytics
- Created 4 years ago
- Comments:6 (1 by maintainers)
Top GitHub Comments
Of course. I added pythia.yaml to the gist. However, it seems that I have lost detectron_model.yaml…
There are two different detectron_model.yaml files available online; maybe they can serve as a guide:
- https://dl.fbaipublicfiles.com/pythia/detectron_model/e2e_faster_rcnn_X-101-64x4d-FPN_1x_MLP_2048_FPN_512.yaml (from https://github.com/facebookresearch/mmf/issues/30)
- https://dl.fbaipublicfiles.com/pythia/detectron_model/detectron_model.yaml (from https://github.com/facebookresearch/mmf/issues/100)
The Pythia implementation I used originally came from this post: https://github.com/facebookresearch/mmf/issues/204. Maybe the person who posted it can help you as well 😃
Best Karol
Hi,
Sure, no problem. I have put the necessary classes into a gist: https://gist.github.com/Karol-G/29a63098b07b79b6cbfad2f8e8a69da4. However, I wrote this as a semester project and the code grew over time, so it is quite messy and undocumented. I still hope it helps.
Coincidentally, I also released a framework today, https://github.com/Karol-G/Gcam, that makes it easy to use Grad-CAM and related methods. However, the framework has only been tested with classification and segmentation, so I don't know whether it will work as easily with VQA.
Best Karol