Edge masking in GNNExplainer implementation

Question

I found that edge masking seems to be missing from the GNNExplainer implementation in PyG.

for epoch in range(1, self.epochs + 1):
    optimizer.zero_grad()
    h = x * self.node_feat_mask.sigmoid()
    out = self.model(x=h, edge_index=edge_index, **kwargs)
    if self.return_type == 'regression':
        loss = self.__loss__(mapping, out, prediction)
    else:
        log_logits = self.__to_log_prob__(out)
        loss = self.__loss__(mapping, log_logits, pred_label)
    loss.backward()
    optimizer.step()

This is the loop for explaining a node, lines 286-298 in the source code. If my understanding is correct, only the node features are masked, by h = x * self.node_feat_mask.sigmoid(), while edge_index, i.e., the adjacency matrix, is not.
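To make the observation concrete, here is a minimal framework-free sketch of what the feature-masking line does. It assumes one mask logit per feature column (the names mask_features and feat_mask_logits are hypothetical, not PyG identifiers) and uses plain Python lists in place of tensors. Note that edge_index is never touched.

```python
import math


def sigmoid(z):
    """Logistic function, mapping a mask logit to a gate in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def mask_features(x, feat_mask_logits):
    """Scale every node's feature vector by the same per-feature gate.

    Hypothetical helper: mirrors h = x * node_feat_mask.sigmoid() under the
    assumption that the mask holds one logit per feature column.
    """
    gates = [sigmoid(m) for m in feat_mask_logits]
    return [[value * gate for value, gate in zip(row, gates)] for row in x]


# Three nodes with two features each, and three directed edges.
x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
edge_index = [[0, 1, 2], [1, 2, 0]]

# Logit 0 halves the first feature; logit -20 almost removes the second.
h = mask_features(x, [0.0, -20.0])

# The graph structure is passed on unchanged: no edge is down-weighted here.
edges_after_masking = edge_index
```

In this sketch the features are attenuated, but every edge still carries its full weight into the model call.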

In the official implementation by the author of GNNExplainer, this is done by the _masked_adj() function. Because the data structures differ, I modified the code as below to do edge masking:

for epoch in range(1, self.epochs + 1):
    optimizer.zero_grad()
    h = x * self.node_feat_mask.sigmoid()
    ### fix
    edge_weight = ori_edge_weight * self.edge_mask.sigmoid()
    out = self.model(x=h, edge_index=edge_index, edge_weight=edge_weight)
    ###
    # out = self.model(x=h, edge_index=edge_index, **kwargs)
    if self.return_type == 'regression':
        loss = self.__loss__(mapping, out, prediction)
    else:
        log_logits = self.__to_log_prob__(out)
        loss = self.__loss__(mapping, log_logits, pred_label)
    loss.backward()
    optimizer.step()

and add

if 'edge_weight' in kwargs:
    # Reuse the caller's edge weights as the starting point.
    ori_edge_weight = kwargs['edge_weight']
else:
    # Default to a weight of 1 per edge. Use the feature dtype (floating
    # point) rather than edge_index.dtype, which is an integer type.
    ori_edge_weight = torch.ones((edge_index.size(1), ), dtype=x.dtype,
                                 device=edge_index.device)

after self.__subgraph__() is called in line 261.
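As a rough illustration of what multiplying the edge weights by a sigmoid-gated mask achieves, the sketch below uses a plain weighted-sum aggregation over an edge list. It is only a toy stand-in for a GNN layer: the aggregate function, the scalar node features, and the example logits are all assumptions made for illustration, and nothing here is PyG code.

```python
import math


def sigmoid(z):
    """Logistic function, mapping a mask logit to a gate in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def aggregate(x, edge_index, edge_weight):
    """Toy message passing: out[dst] += weight * x[src] for every edge."""
    sources, targets = edge_index
    out = [0.0] * len(x)
    for src, dst, weight in zip(sources, targets, edge_weight):
        out[dst] += weight * x[src]
    return out


# One scalar feature per node; edges are 0->2, 1->2 and 2->0.
x = [1.0, 10.0, 100.0]
edge_index = [[0, 1, 2], [2, 2, 0]]

# Default weights of 1 per edge, as in the fallback branch of the fix.
ori_edge_weight = [1.0, 1.0, 1.0]

# Hypothetical mask logits: keep edge 0, suppress edge 1, halve edge 2.
edge_mask_logits = [20.0, -20.0, 0.0]
edge_weight = [w * sigmoid(m)
               for w, m in zip(ori_edge_weight, edge_mask_logits)]

unmasked = aggregate(x, edge_index, ori_edge_weight)
masked = aggregate(x, edge_index, edge_weight)
```

With these values, node 2 receives messages from nodes 0 and 1 in the unmasked case, but almost only from node 0 once the mask is applied. This only has an effect in a real model if its layers actually use the edge_weight argument.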

Please let me know if this is incorrect or if I have misunderstood something 😃 By the way, I did not try modifying the explain_graph() function, but I suppose it is a similar case.

Thanks a lot.

Environment

  • PyG version: 2.0.2
  • PyTorch version: 1.8.0
  • OS: MacOS
  • Python version: 3.8
  • How you installed PyTorch and PyG (conda, pip, source): pip

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Comments: 5 (2 by maintainers)

Top GitHub Comments

Gori-LV commented, Feb 26, 2022 (1 reaction)

thanks a lot!

rusty1s commented, Feb 22, 2022 (1 reaction)

Yes, that is correct. You need to take care of setting the correct mask for each individual type. Otherwise, everything else should stay the same.
