A question about captum in graph classification tasks
🐛 Describe the bug
I am running a graph classification task, and after training the model I run:
captum_model = to_captum(model, mask_type='edge')
edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)
ig = IntegratedGradients(captum_model)
ig_attr_edge = ig.attribute(edge_mask.unsqueeze(0), target=data.y, additional_forward_args=(data), internal_batch_size=64)
Here, data is a batch from my train_loader. I get the error: "Dimension 0 of input should be 1". But edge_mask.unsqueeze(0) has shape (1, num_edges). Could you help me with this problem? Thanks!
Environment
- PyG version: 1.10
- PyTorch version: 2.03
- OS:
- Python version: 3.8
- CUDA/cuDNN version: 113
- How you installed PyTorch and PyG (conda, pip, source): conda
- Any other relevant information (e.g., version of torch-scatter):
Issue Analytics
- State:
- Created 2 years ago
- Comments: 13 (6 by maintainers)
Top Results From Across the Web
Some questions about dgl&captum - Deep Graph Library
When doing node classification tasks, we use captum's IG algorithm to interpret the features of the edges. What is the meaning of the...
Captum vs GNNExplainer for explainability in Graph Neural ...
I'm new to Graph Neural Networks and interested in exploring frameworks that allow the identification of nodes/edges that underlie ...

Tutorials - Captum · Model Interpretability for PyTorch
This tutorial demonstrates how to apply TCAV algorithm for a NLP task using movie rating dataset and a CNN-based binary sentiment classification model....

Colab Notebooks and Video Tutorials - PyTorch Geometric
Scaling Graph Neural Networks · Point Cloud Classification with Graph Neural Networks · Explaining GNN Model Predictions using Captum · Customizing Aggregations ...

Captum: A unified and generic model interpretability ... - arXiv
graph-structured models built on Neural Networks (NN). In this paper we ... image classification tasks. ... and PyTorch related questions.
Read more >Top Related Medium Post
No results found
Top Related StackOverflow Question
No results found
Troubleshoot Live Code
Lightrun enables developers to add logs, metrics and snapshots to live code - no restarts or redeploys required.
Start FreeTop Related Reddit Thread
No results found
Top Related Hackernoon Post
No results found
Top Related Tweet
No results found
Top Related Dev.to Post
No results found
Top Related Hashnode Post
No results found
Top GitHub Comments
You’re right! Thanks. I had seen this issue with NLP in Captum but didn’t make the connection. To make this work, I created a module that returns the embeddings using a forward hook, and subclassed my model to rewrite the forward function while keeping my init unchanged. For the embedding extraction I found a helpful example in a forum, which allowed me to extract the trained embeddings. Then I just had to load the weights shared between the reduced subclassed model and the full model. Thanks for your help!
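The forward-hook approach described above can be sketched as follows (the model and layer names are illustrative, not from the original post): register a hook on the layer whose output you want, run one forward pass, and read the captured tensor.

```python
import torch

# Toy model: an embedding layer followed by a classifier head
# (hypothetical stand-in for the trained GNN's embedding stage).
model = torch.nn.Sequential(
    torch.nn.Embedding(100, 16),
    torch.nn.Linear(16, 2),
)

captured = {}

def hook(module, inputs, output):
    # Save the layer's output; detach so it is not part of the autograd graph.
    captured['emb'] = output.detach()

handle = model[0].register_forward_hook(hook)
model(torch.tensor([1, 2, 3]))   # one forward pass fills captured['emb']
handle.remove()                  # clean up so later passes are unaffected

print(captured['emb'].shape)  # torch.Size([3, 16])
```

The weights shared between a reduced subclassed model and the full model can then be copied with submodel.load_state_dict(full_model.state_dict(), strict=False), which ignores keys that only exist in the full model.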
Thanks, I tried the example: https://colab.research.google.com/drive/1fLJbFPz0yMCQg81DdCP5I8jXw9LoggKO?usp=sharing#scrollTo=9Hh3YNASuYxm
It generates an explanation for one selected graph.