can we collate data and save to disk when overriding __inc__?
❓ Questions & Help
Hi, thanks for the great library! I have a question. I am writing a custom data class, similar to the BipartiteData example described here: https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html. However, I noticed that the examples there don't call:
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
Since my dataset is large, I would prefer to process it once and save it to disk. However, calling self.collate fails for me. My custom Data class looks like this:
import torch
from torch_geometric.data import Data

class CorrData(Data):
    def __init__(self, vtx, pts, edge_vtx, edge_pts, corr, name):
        super(CorrData, self).__init__()
        self.vtx = vtx
        self.pts = pts
        self.edge_index_vtx = edge_vtx
        self.edge_index_pts = edge_pts
        self.corr = corr
        self.name = name

    def __inc__(self, key, value):
        # Offset added to index-like attributes when graphs are batched.
        if key == 'edge_index_vtx':
            return self.vtx.size(0)
        if key == 'edge_index_pts':
            return self.pts.size(0)
        if key == 'corr':
            return torch.tensor([[self.vtx.size(0)], [self.pts.size(0)]])
        else:
            return super(CorrData, self).__inc__(key, value)

    def __cat_dim__(self, key, value):
        # Edge indices are concatenated along dim 1, everything else along dim 0.
        if 'edge_index' in key:
            return 1
        else:
            return 0
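For readers unfamiliar with what these two overrides control, here is a minimal pure-Python sketch of the idea, under stated assumptions: the value returned by `__inc__` acts as an offset added to an index attribute for each successive graph, and `__cat_dim__` picks the axis along which the attribute is concatenated. The helper `batch_edge_index` and the toy graphs are hypothetical and use plain lists rather than tensors; this is an illustration, not PyTorch Geometric's own implementation.

```python
def batch_edge_index(edge_indices, num_nodes):
    """Hypothetical helper: merge per-graph edge lists into one edge list.

    edge_indices: list of (sources, targets) pairs, one per graph.
    num_nodes:    number of nodes per graph, playing the role of __inc__.
    """
    sources, targets = [], []
    offset = 0
    for (src, dst), n in zip(edge_indices, num_nodes):
        # Shift node ids so they stay unique in the merged graph.
        sources.extend(s + offset for s in src)
        targets.extend(d + offset for d in dst)
        # The next graph starts after all nodes seen so far.
        offset += n
    # Appending columns corresponds to concatenating along dim 1.
    return [sources, targets]


graph_a = ([0, 1], [1, 2])  # 3 nodes, edges 0->1 and 1->2
graph_b = ([0, 1], [1, 0])  # 2 nodes, edges 0->1 and 1->0
batched = batch_edge_index([graph_a, graph_b], [3, 2])
print(batched)
```

The second graph's node ids are shifted by 3, the node count of the first graph, which is what returning `self.vtx.size(0)` or `self.pts.size(0)` from `__inc__` is meant to achieve for the respective edge index.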
The reported error is like this:
  File "xxxx/datasets/corr_dataset.py", line 141, in process
    data, slices = self.collate(data_list)
  File "xxxx/torch_geometric/data/in_memory_dataset.py", line 93, in collate
    data = data_list[0].__class__()
TypeError: __init__() missing 6 required positional arguments: 'vtx', 'pts', 'edge_vtx', 'edge_pts', 'corr', and 'name'
Any suggestion about how to solve this? Thanks!
Top GitHub Comments
Regarding your first issue: this was a bug introduced in PyG 1.6.0, and it is fixed in PyG 1.6.1. Sorry! Regarding your warning: do you know exactly where this warning occurs? This shouldn't happen either.
In general, collate should be usable even for custom Data objects. You can fix the error by changing: ... to ...
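The traceback shows collate building an empty object with `data_list[0].__class__()`, so the custom class has to be constructible without arguments. The sketch below reproduces that mechanism with plain Python stand-ins. The class names and the helper `make_empty_like` are hypothetical, and giving every constructor argument a default of `None` is shown as one common workaround, an assumption here rather than a confirmed statement of the change proposed in the comment.

```python
# Pure-Python stand-ins; no PyTorch Geometric classes are used here.

class RequiredArgsData:
    """Mimics a custom data class whose __init__ has only required arguments."""

    def __init__(self, vtx, pts):
        self.vtx = vtx
        self.pts = pts


class OptionalArgsData:
    """Same fields, but every argument defaults to None (assumed workaround)."""

    def __init__(self, vtx=None, pts=None):
        self.vtx = vtx
        self.pts = pts


def make_empty_like(sample):
    """Hypothetical helper following the `data_list[0].__class__()` pattern."""
    return sample.__class__()


def can_make_empty(sample):
    """Return True if an empty object of the sample's class can be created."""
    try:
        make_empty_like(sample)
        return True
    except TypeError:
        # Raised when __init__ is missing required positional arguments.
        return False


required_ok = can_make_empty(RequiredArgsData([0, 1], [2, 3]))
optional_ok = can_make_empty(OptionalArgsData([0, 1], [2, 3]))
print(required_ok, optional_ok)
```

With defaults in place, the no-argument construction succeeds and the fields can be filled in afterwards, which is the pattern the collate call in the traceback relies on.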