
A question about captum in graph classification tasks

See original GitHub issue

🐛 Describe the bug

I am running a graph classification task and when I finish training the model, I run:

captum_model = to_captum(model, mask_type='edge')
edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)

ig = IntegratedGradients(captum_model)
ig_attr_edge = ig.attribute(edge_mask.unsqueeze(0), target=data.y, additional_forward_args=(data), internal_batch_size=64)

Here, data is a batch from my train_loader. I get the error “Dimension 0 of input should be 1”, even though edge_mask.unsqueeze(0) has shape (1, num_edges). Could you help me with this problem? Thanks!
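For reference (anticipating the fix discussed in the comments below), a call that works for a single graph looks roughly like this. This is a sketch that assumes data is a single Data object rather than a batch, that model and device are defined as above, and that to_captum lives at torch_geometric.nn as in PyG 2.0.x:

import torch
from captum.attr import IntegratedGradients
from torch_geometric.nn import to_captum

captum_model = to_captum(model, mask_type='edge')
edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)

ig = IntegratedGradients(captum_model)
ig_attr_edge = ig.attribute(
    edge_mask.unsqueeze(0),
    target=int(data.y),                                  # a single scalar target
    additional_forward_args=(data.x, data.edge_index),   # an actual tuple, not (data)
    internal_batch_size=1,
)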

Environment

  • PyG version: 2.0.3
  • PyTorch version: 1.10
  • OS:
  • Python version: 3.8
  • CUDA/cuDNN version: 11.3
  • How you installed PyTorch and PyG (conda, pip, source): conda
  • Any other relevant information (e.g., version of torch-scatter):

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Comments: 13 (6 by maintainers)

Top GitHub Comments

2 reactions
bgeier commented, Apr 7, 2022

You’re right, thanks! I had seen this issue with NLP in Captum but didn’t make the connection. To make this work, I created a module that returns the embeddings using a forward hook, and I subclassed my model to rewrite the forward function while keeping my __init__ unchanged. For the embedding extraction, I found this helpful example in a forum:

import torch
from typing import Dict, Iterable, Callable
from torch import nn
from torch import Tensor

# A wrapper that returns embeddings from trained embedding layers via forward hooks.
# Calling it returns a dictionary mapping layer name -> output tensor.
class FeatureExtractor(nn.Module):
    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        self.layers = layers
        self._features = {layer: torch.empty(0) for layer in layers}

        for layer_id in layers:
            layer = dict([*self.model.named_modules()])[layer_id]
            layer.register_forward_hook(self.save_outputs_hook(layer_id))

    def save_outputs_hook(self, layer_id: str) -> Callable:
        def fn(_, __, output):
            self._features[layer_id] = output
        return fn

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor, batch: Tensor) -> Dict[str, Tensor]:
        _ = self.model(x, edge_index, edge_attr, batch)
        return self._features

which allowed me to extract the trained embeddings via

embedder = FeatureExtractor(model=model, layers=['ex','ey','ez']) 
emb_dict = embedder(data.x, data.edge_index, data.edge_attr, data.batch)
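A quick sanity check on the extraction (my addition, not part of the original comment) is to print the shape of each captured tensor:

# Each entry should match the output shape of the corresponding embedding layer.
for name, t in emb_dict.items():
    print(name, tuple(t.shape))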

Now I just had to load the weights shared between the reduced, subclassed model and the full model:

# define a new, freshly initialized model for just the GNN portion (i.e. post integer embedding)
gnn = newModel(emb_x, **params)  # a subclassed model whose forward starts with the embedding cat
# get pretrained weights (checkpoint was loaded earlier, e.g. with torch.load)
pretrained_dict = checkpoint['model_state_dict']
# state dict of the subclassed instance
model_dict = gnn.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
gnn.load_state_dict(model_dict)
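A shorter variant of the same weight transfer (my addition, not part of the original comment) relies on load_state_dict with strict=False, which skips non-matching keys and reports exactly what was left out:

# strict=False ignores checkpoint keys the subclassed model lacks (and vice versa);
# the returned named tuple lists the skipped keys, which is a useful sanity check.
result = gnn.load_state_dict(checkpoint['model_state_dict'], strict=False)
print('missing keys:', result.missing_keys)
print('unexpected keys:', result.unexpected_keys)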

import numpy as np
import torch_geometric as pyg
from captum.attr import IntegratedGradients

captum_model = pyg.nn.models.explainer.to_captum(gnn, mask_type='node')
emb_dict = embedder(data.x, data.edge_index, data.edge_attr, data.batch)
x = torch.cat((emb_dict['ex'], emb_dict['ey'], emb_dict['ez']), dim=1)

ig = IntegratedGradients(captum_model)
ig_attr = ig.attribute(
    x.unsqueeze(0),
    target=int(data.y),
    additional_forward_args=(data.edge_index, data.edge_attr, data.batch),
    internal_batch_size=1,
)
node_mask = np.abs(ig_attr[0].cpu().detach().numpy())
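Here node_mask has shape (num_nodes, num_features). A common follow-up (my addition, not in the original comment) is to collapse it to one importance score per node:

# Sum absolute attributions over the feature dimension to get a per-node score,
# then normalize so scores are comparable across graphs.
node_importance = node_mask.sum(axis=1)
node_importance = node_importance / (node_importance.max() + 1e-12)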

Thanks for your help!

1 reaction
juanshu30 commented, Mar 22, 2022

> Currently, you can’t use to_captum for graph classification batch-wise. Does the method work if you select only one graph and specify internal_batch_size = 1? Setting output_idx = 0 should not be necessary.
>
> Also set additional_forward_args = (data.x, data.edge_index).

Thanks, I tried the example: https://colab.research.google.com/drive/1fLJbFPz0yMCQg81DdCP5I8jXw9LoggKO?usp=sharing#scrollTo=9Hh3YNASuYxm

It generates an explanation for one selected graph.
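Since to_captum can’t be applied batch-wise for graph classification, one workaround is to iterate over the graphs one at a time. This is a sketch under that assumption, not taken from the thread; dataset, device, and the ig object are assumed to be defined as in the question above:

from torch_geometric.loader import DataLoader

# Explain each graph individually with batch_size=1.
single_loader = DataLoader(dataset, batch_size=1, shuffle=False)
edge_masks = []
for data in single_loader:
    data = data.to(device)
    edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)
    attr = ig.attribute(
        edge_mask.unsqueeze(0),
        target=int(data.y),
        additional_forward_args=(data.x, data.edge_index),
        internal_batch_size=1,
    )
    edge_masks.append(attr.squeeze(0).detach().cpu())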


Top Results From Across the Web

Some questions about dgl & captum - Deep Graph Library
When doing node classification tasks, we use Captum’s IG algorithm to interpret the features of the edges. What is the meaning of the...

Captum vs GNNExplainer for explainability in Graph Neural Networks
I’m new to Graph Neural Networks and interested in exploring frameworks that allow the identification of nodes/edges that underlie...

Tutorials - Captum · Model Interpretability for PyTorch
This tutorial demonstrates how to apply the TCAV algorithm to an NLP task using a movie rating dataset and a CNN-based binary sentiment classification model...

Colab Notebooks and Video Tutorials - PyTorch Geometric
Scaling Graph Neural Networks · Point Cloud Classification with Graph Neural Networks · Explaining GNN Model Predictions using Captum · Customizing Aggregations...

Captum: A unified and generic model interpretability library for PyTorch - arXiv
Graph-structured models built on Neural Networks (NN). In this paper we... image classification tasks... and PyTorch related questions.
