Cannot train model with custom HeteroData dataset.
Describe the bug
Hi, first of all, thanks for the amazing work you do with PyG!
I'm new to graphs and I'm trying to create a custom HeteroData dataset. So far, so good: creating the dataset seems to work as expected. I used the OGB_MAG dataset as a reference, since the examples I saw use that one.
So, at the moment I am able to access my dataset object:
>>> dataset[0]
HeteroData(
foo={
x=[18, 100],
y=[18, 1],
train_mask=[18],
test_mask=[18]
},
bar={ x=[6, 100] },
(foo, belongs_to, bar)={ edge_index=[18, 2] },
(bar, connects_to, bar)={ edge_index=[4, 2] }
)
I know it's a small dataset object, but to reduce processing time I decided to take a small subset of the original dataset. Basically, nodes of type foo belong to (are connected to) nodes of type bar, so foo nodes are like "end nodes" whose edges always connect to a node of type bar. Nodes of type bar, on the other hand, are entities that can be connected to each other.
The issue appears when I try to train the model on this particular graph. I'm training the model with the example provided here, more precisely its GNN class.
However, I get the following error:
Traceback (most recent call last):
File "main.py", line 66, in <module>
out = model(data.x_dict, data.edge_index_dict)
File "/usr/local/lib/python3.8/site-packages/torch/fx/graph_module.py", line 616, in wrapped_call
raise e.with_traceback(None)
AssertionError
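One likely reading of the bare AssertionError: PyG's message-passing layers assert that edge_index is a two-dimensional torch.long tensor with exactly two rows, i.e. shape [2, num_edges]. The repr above reports edge_index=[18, 2] and edge_index=[4, 2], which looks like the transposed [num_edges, 2] layout. The sketch below, assuming only PyTorch, runs similar checks on a dict shaped like data.edge_index_dict before the model is called. check_edge_indices is a hypothetical helper, not a PyG function.

```python
import torch


def check_edge_indices(edge_index_dict, num_nodes_dict):
    """Return human-readable problems found in a PyG-style edge_index dict.

    edge_index_dict: {(src_type, relation, dst_type): tensor}, where each
    tensor is expected to have shape [2, num_edges] and dtype torch.long.
    num_nodes_dict: {node_type: number of nodes of that type}.
    """
    problems = []
    for (src, rel, dst), edge_index in edge_index_dict.items():
        name = f"({src}, {rel}, {dst})"
        if edge_index.dtype != torch.long:
            problems.append(f"{name}: dtype is {edge_index.dtype}, expected torch.long")
        # A [num_edges, 2] tensor fails here unless num_edges happens to be 2,
        # in which case the transposed layout goes unnoticed by this check.
        if edge_index.dim() != 2 or edge_index.size(0) != 2:
            problems.append(f"{name}: shape is {list(edge_index.shape)}, expected [2, num_edges]")
            continue  # the index-range checks below assume two rows
        if edge_index.numel() == 0:
            continue  # no edges, nothing to range-check
        if int(edge_index.min()) < 0:
            problems.append(f"{name}: negative node index")
        if int(edge_index[0].max()) >= num_nodes_dict[src]:
            problems.append(f"{name}: source index out of range for '{src}'")
        if int(edge_index[1].max()) >= num_nodes_dict[dst]:
            problems.append(f"{name}: target index out of range for '{dst}'")
    return problems


# Shapes copied from the repr in the issue ([num_edges, 2] instead of [2, num_edges]).
num_nodes = {"foo": 18, "bar": 6}
edge_index_dict = {
    ("foo", "belongs_to", "bar"): torch.zeros(18, 2, dtype=torch.long),
    ("bar", "connects_to", "bar"): torch.zeros(4, 2, dtype=torch.long),
}
for problem in check_edge_indices(edge_index_dict, num_nodes):
    print(problem)
```

This prints one shape problem per edge type, e.g. "(foo, belongs_to, bar): shape is [18, 2], expected [2, num_edges]". With the real dataset, the call would be check_edge_indices(data.edge_index_dict, {'foo': data['foo'].num_nodes, 'bar': data['bar'].num_nodes}).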
I know it looks very "general" and doesn't contain much information, but I hope someone has experienced something similar and is willing to give me a hand.
Thanks in advance!
Environment
- PyG version: 2.0.3
- PyTorch version: 1.10.2
- OS: macOS Monterey Version 12.2.1
- Python version: 3.8.12
- CUDA/cuDNN version: N/A
- How you installed PyTorch and PyG (conda, pip, source): pip
- Any other relevant information (e.g., version of torch-scatter): N/A
Issue created 2 years ago (5 comments).
Top GitHub Comments
Never mind, I already found the cause of this particular error: the tensor's size was not correct.
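The comment does not say which tensor was wrong. If it was the transposed edge_index layout visible in the repr, the usual fix is to transpose before storing, e.g. data['foo', 'belongs_to', 'bar'].edge_index = edge_index.t().contiguous(). A minimal sketch, assuming the raw edges are available as a list of (source, target) index pairs; pairs_to_edge_index is a hypothetical helper name:

```python
import torch


def pairs_to_edge_index(pairs):
    """Convert [(src, dst), ...] pairs into a [2, num_edges] long tensor."""
    edge_index = torch.tensor(pairs, dtype=torch.long)  # shape [num_edges, 2]
    if edge_index.numel() == 0:
        return edge_index.new_empty((2, 0))  # empty graph: shape [2, 0]
    if edge_index.dim() != 2 or edge_index.size(1) != 2:
        raise ValueError(f"expected (src, dst) pairs, got shape {list(edge_index.shape)}")
    # Transpose to [2, num_edges]; .contiguous() copies into a compact layout.
    return edge_index.t().contiguous()


print(pairs_to_edge_index([(0, 1), (2, 3), (4, 5)]).tolist())  # [[0, 2, 4], [1, 3, 5]]
```

Row 0 then holds the source node indices and row 1 the target node indices, which is the layout the message-passing layers expect.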
Alright, thanks.