RandomLinkSplit: KeyError on 'edge_index' when the input is HeteroData

🐛 Bug

Thank you for the great tool. Neither RandomLinkSplit nor RandomNodeSplit supports HeteroData yet.

To Reproduce

Steps to reproduce the behavior:

import torch_geometric.transforms as T
from torch_geometric.data import HeteroData

data = HeteroData()  # the report's own node/edge data is not shown
data = T.ToUndirected()(data)
data = T.AddSelfLoops()(data)
data = T.NormalizeFeatures()(data)
train_data, val_data, test_data = T.RandomLinkSplit(is_undirected=True,
                                                    add_negative_train_samples=True,
                                                    neg_sampling_ratio=1.0)(data)

Error message

KeyError                                  Traceback (most recent call last)
~/miniconda/envs/drugai/lib/python3.9/site-packages/torch_geometric/data/ in __getattr__(self, key)
     47         try:
---> 48             return self[key]
     49         except KeyError:

~/miniconda/envs/drugai/lib/python3.9/site-packages/torch_geometric/data/ in __getitem__(self, key)
     67     def __getitem__(self, key: str) -> Any:
---> 68         return self._mapping[key]

KeyError: 'edge_index'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
<ipython-input-34-458eaa1ce7eb> in <module>
     11 #  works like in the homogenous case, and normalizes all specified features (of all types) to sum up to one.
     12 data = T.NormalizeFeatures()(data)
---> 13 train_data, val_data, test_data = T.RandomLinkSplit(is_undirected=True, 
     14                                                     add_negative_train_samples=True,
     15                                                     neg_sampling_ratio=1.0)(data)

~/miniconda/envs/drugai/lib/python3.9/site-packages/torch_geometric/transforms/ in __call__(self, data)
     82     def __call__(self, data: Data) -> Tuple[Data, Data, Data]:
---> 83         perm = torch.randperm(data.num_edges, device=data.edge_index.device)
     84         if self.is_undirected:
     85             perm = perm[data.edge_index[0] <= data.edge_index[1]]

~/miniconda/envs/drugai/lib/python3.9/site-packages/torch_geometric/data/ in __getattr__(self, key)
    116             if len(out) > 0:
    117                 return out
--> 118         return getattr(self._global_store, key)
    120     def __setattr__(self, key: str, value: Any):

~/miniconda/envs/drugai/lib/python3.9/site-packages/torch_geometric/data/ in __getattr__(self, key)
     48             return self[key]
     49         except KeyError:
---> 50             raise AttributeError(
     51                 f"'{self.__class__.__name__}' object has no attribute '{key}'")

AttributeError: 'BaseStorage' object has no attribute 'edge_index'

Expected behavior

RandomLinkSplit should work for HeteroData.

Environment

  • OS: Linux
  • Python version: 3.9
  • PyTorch version: 1.9
  • CUDA/cuDNN version: None

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 5 (4 by maintainers)

Top GitHub Comments

rusty1s commented, Sep 21, 2021

This is now available in master.

rusty1s commented, Sep 17, 2021
