Edge masking in GNNExplainer implementation
I found that edge masking seems to be missing from PyG's implementation of GNNExplainer.
for epoch in range(1, self.epochs + 1):
    optimizer.zero_grad()
    h = x * self.node_feat_mask.sigmoid()
    out = self.model(x=h, edge_index=edge_index, **kwargs)
    if self.return_type == 'regression':
        loss = self.__loss__(mapping, out, prediction)
    else:
        log_logits = self.__to_log_prob__(out)
        loss = self.__loss__(mapping, log_logits, pred_label)
    loss.backward()
    optimizer.step()
This is the node-explanation loop, lines 286-298 of the source code. If my understanding is correct, only the node features are masked, by h = x * self.node_feat_mask.sigmoid(), while edge_index (i.e., the adjacency matrix) is not.
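A toy illustration of this concern, assuming a hypothetical sum-aggregation function `propagate` in place of the real model (it is not a PyG API): if only `node_feat_mask` appears in the forward pass, no gradient can reach an edge mask, so the optimizer never updates it. Whether PyG applies an edge mask somewhere else, for example inside its message-passing layers, is not visible from this loop alone, so this only sketches the question being asked.

```python
import torch

torch.manual_seed(0)

# Toy graph: 3 nodes with 2 features each, 3 directed edges (row 0 = source, row 1 = target).
x = torch.randn(3, 2)
edge_index = torch.tensor([[0, 1, 2],
                           [1, 2, 0]])

# Learnable masks, initialized to zero so that sigmoid(mask) = 0.5 everywhere.
node_feat_mask = torch.nn.Parameter(torch.zeros(2))
edge_mask = torch.nn.Parameter(torch.zeros(edge_index.size(1)))

def propagate(h, edge_index):
    """Hypothetical toy layer: out[dst] += h[src] for every edge."""
    src, dst = edge_index
    return torch.zeros_like(h).index_add(0, dst, h[src])

# Same pattern as the loop above: only the node features are multiplied by a mask.
h = x * node_feat_mask.sigmoid()
out = propagate(h, edge_index)
out.sum().backward()

print(node_feat_mask.grad is not None)  # True: the feature mask is part of the graph
print(edge_mask.grad is None)           # True: the edge mask never entered the forward pass
```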
In the official implementation by the GNNExplainer authors, edge masking is done by the _masked_adj() function. Because PyG represents the graph differently (an edge_index edge list rather than an adjacency matrix), I modified the code as below to do edge masking:
for epoch in range(1, self.epochs + 1):
    optimizer.zero_grad()
    h = x * self.node_feat_mask.sigmoid()
    ### fix
    # Scale each edge's original weight by its learned (sigmoid) mask value.
    edge_weight = ori_edge_weight * self.edge_mask.sigmoid()
    # Keep forwarding the other kwargs, overriding edge_weight with the masked one.
    out = self.model(x=h, edge_index=edge_index,
                     **{**kwargs, 'edge_weight': edge_weight})
    ###
    # out = self.model(x=h, edge_index=edge_index, **kwargs)
    if self.return_type == 'regression':
        loss = self.__loss__(mapping, out, prediction)
    else:
        log_logits = self.__to_log_prob__(out)
        loss = self.__loss__(mapping, log_logits, pred_label)
    loss.backward()
    optimizer.step()
I also added the following right after self.__subgraph__() is called in line 261:
if kwargs.get('edge_weight') is not None:
    # Reuse the edge weights passed by the caller.
    ori_edge_weight = kwargs['edge_weight']
else:
    # No weights given: treat every edge as weight 1.0. Edge weights are
    # real-valued, so use the feature dtype rather than edge_index's int64.
    ori_edge_weight = torch.ones((edge_index.size(1), ), dtype=x.dtype,
                                 device=edge_index.device)
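The edge-masking change can be checked on a toy graph. This is a minimal sketch, assuming a hypothetical weighted sum-aggregation `propagate` standing in for the real model (not a PyG function). It fills in default weights of 1.0, computes the masked weights the same way as the modified loop, and compares the result with multiplying a dense weighted adjacency matrix by `x`, which is the dense-adjacency view of the same masking (the authors' _masked_adj() may differ in details). It also confirms that once the mask enters the forward pass, `edge_mask` receives a gradient.

```python
import torch

torch.manual_seed(0)

N, F = 4, 3
x = torch.randn(N, F)
# Directed edges in COO form (row 0 = source, row 1 = target), like edge_index.
edge_index = torch.tensor([[0, 1, 2, 3, 0],
                           [1, 2, 3, 0, 2]])
E = edge_index.size(1)

def propagate(h, edge_index, edge_weight):
    """Hypothetical weighted sum aggregation: out[dst] += w * h[src]."""
    src, dst = edge_index
    msg = edge_weight.view(-1, 1) * h[src]
    return torch.zeros_like(h).index_add(0, dst, msg)

kwargs = {}  # the caller passed no edge_weight
if kwargs.get('edge_weight') is not None:
    ori_edge_weight = kwargs['edge_weight']
else:
    # Default every edge to weight 1.0, in the feature dtype.
    ori_edge_weight = torch.ones(E, dtype=x.dtype)

edge_mask = torch.nn.Parameter(torch.randn(E))
edge_weight = ori_edge_weight * edge_mask.sigmoid()
out_sparse = propagate(x, edge_index, edge_weight)

# Same computation with a dense masked adjacency: A[dst, src] = weight of edge src -> dst.
masked_adj = torch.zeros(N, N).index_put((edge_index[1], edge_index[0]), edge_weight)
out_dense = masked_adj @ x
print(torch.allclose(out_sparse, out_dense, atol=1e-6))  # True

# The mask is now part of the forward pass, so an optimizer can update it.
out_sparse.sum().backward()
print(edge_mask.grad is not None)  # True
```

Testing `kwargs.get('edge_weight') is not None` instead of `'edge_weight' in kwargs` also covers a caller that passes `edge_weight=None` explicitly, which would otherwise make `ori_edge_weight * ...` fail.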
Please let me know if this is incorrect or if I have misunderstood something 😃 By the way, I did not try modifying the explain_graph() function, but I suppose the same change applies there.
Thanks a lot.
Environment
- PyG version: 2.0.2
- PyTorch version: 1.8.0
- OS: MacOS
- Python version: 3.8
- How you installed PyTorch and PyG (conda, pip, source): pip
Top GitHub Comments
Thanks a lot!
Yes, that is correct. You need to take care of setting the correct mask for each individual type. Otherwise, everything else should stay the same.