Error in `scatter_sum` with pre-transform that adds a node to the graph
❓ Questions & Help
Hi!
I get an error in `scatter_sum` that I really don't understand. It only happens with a new pre-transform that I wrote.
ERROR:
```
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/Users/linda/miniconda2/envs/GraphNeuralNetworkExperiments/lib/python3.8/site-packages/torch_scatter/scatter.py", line 12, in scatter_sum
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    index = broadcast(index, src, dim)
            ~~~~~~~~~ <--- HERE
    if out is None:
        size = list(src.size())
  File "/Users/linda/miniconda2/envs/GraphNeuralNetworkExperiments/lib/python3.8/site-packages/torch_scatter/utils.py", line 13, in broadcast
    for _ in range(src.dim(), other.dim()):
        src = src.unsqueeze(-1)
    src = src.expand_as(other)
          ~~~~~~~~~~~~~ <--- HERE
    return src
RuntimeError: The expanded size of the tensor (212) must match the existing size (196) at non-singleton dimension 0. Target sizes: [212, 128]. Tensor sizes: [196, 1]
```
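The shape logic in the two frames of the traceback can be reproduced with plain PyTorch. `broadcast` appends trailing singleton dimensions to `index` (a 1-D index of length 196 becomes `[196, 1]`) and then calls `expand_as(src)`. `expand_as` can only stretch size-1 dimensions, so dimension 0 of `index` (196) must already equal dimension 0 of `src` (212). The helper `broadcast_index` below is a hypothetical, simplified re-implementation of just those lines for `dim=0`, not the library's own `broadcast`.

```python
import torch


def broadcast_index(index: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    """Simplified sketch of the shape handling in torch_scatter's `broadcast`
    for dim=0 (hypothetical helper, not the library function)."""
    # Append trailing singleton dims, e.g. [196] -> [196, 1].
    for _ in range(index.dim(), src.dim()):
        index = index.unsqueeze(-1)
    # Only size-1 dims can be expanded, so dim 0 must already match src.
    return index.expand_as(src)


src = torch.randn(212, 128)  # 212 rows of 128-dim values, as in the error

bad_index = torch.zeros(196, dtype=torch.long)
try:
    broadcast_index(bad_index, src)
except RuntimeError as err:
    print("RuntimeError:", err)  # same kind of size-mismatch error as above

good_index = torch.zeros(212, dtype=torch.long)
print(broadcast_index(good_index, src).shape)  # torch.Size([212, 128])
```

In other words, whatever calls `scatter_sum` is passing an index with fewer entries than `src` has rows, which usually points to inconsistent graph data rather than a bug in `torch_scatter`.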
PRE-TRANSFORM:
```python
class CentralNode(object):
    def __call__(self, data):
        # add the central node
        central_node_features = torch.mean(data.x, 0)
        data.x = torch.cat((data.x, central_node_features.unsqueeze(0)), 0)
        # add edges to all other nodes
        additional_edges = torch.tensor([list(range(data.num_nodes)), [data.num_nodes]*data.num_nodes])
        data.edge_index = torch.cat((data.edge_index, additional_edges), 1)
        # give it a position
        if data.pos is not None:
            central_node_pos = torch.mean(data.pos, 0)
            data.pos = torch.cat((data.pos, central_node_pos.unsqueeze(0)), 0)
        return data

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)
```
Without the pre-transform, everything works fine. Any help is appreciated 😃!
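One likely culprit in the transform as posted, assuming `data.num_nodes` is inferred from `data.x` (which PyTorch Geometric does when `num_nodes` is not stored explicitly): `data.num_nodes` is read *after* `data.x` has been extended, so it already counts the new node. With N original nodes, the added edges then target index N + 1 (and also start from node N), while `data.x` only has rows 0..N, so `edge_index.max()` is no longer smaller than `x.size(0)`. A minimal sketch of a corrected version records the node count before the concatenation and connects the N original nodes to the central node at index N:

```python
import torch


class CentralNode(object):
    """Pre-transform sketch: add a virtual central node linked from every original node."""

    def __call__(self, data):
        # Count nodes BEFORE extending data.x; the central node gets index n.
        n = data.x.size(0)

        # Central node features: mean of all original node features.
        central_node_features = torch.mean(data.x, 0)
        data.x = torch.cat((data.x, central_node_features.unsqueeze(0)), 0)

        # Edges i -> n for every original node i, so edge_index.max() == n < x.size(0).
        additional_edges = torch.stack([
            torch.arange(n, dtype=torch.long),
            torch.full((n,), n, dtype=torch.long),
        ])
        data.edge_index = torch.cat((data.edge_index, additional_edges), 1)

        # Central node position: mean of the original positions, if present.
        if getattr(data, 'pos', None) is not None:
            central_node_pos = torch.mean(data.pos, 0)
            data.pos = torch.cat((data.pos, central_node_pos.unsqueeze(0)), 0)
        return data

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)
```

The edge direction (original node as source, central node as target) is kept from the posted transform. If your `Data` objects store `num_nodes` explicitly, that value would also need to be incremented, since it is no longer derived from `x`.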
Issue Analytics
- Created 3 years ago
- Comments: 14 (7 by maintainers)
Top GitHub Comments
It looks like this may be a problem with your data. Please ensure that `edge_index.max()` is lower than `x.size(0)`.

Thank you for your explanation. With this, I could solve the issue.
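The maintainer's condition can be turned into an explicit sanity check that runs right after a (pre-)transform, so an out-of-range index fails with a readable message instead of a shape error deep inside `torch_scatter`. `check_edge_index` is a hypothetical helper, not part of PyTorch Geometric; it assumes `edge_index` is a `[2, E]` integer tensor as in the transform above.

```python
import torch


def check_edge_index(x: torch.Tensor, edge_index: torch.Tensor) -> None:
    """Hypothetical helper: raise if edge_index refers to rows that x does not have."""
    num_nodes = x.size(0)
    if edge_index.numel() == 0:
        return  # no edges, nothing to check
    lo, hi = int(edge_index.min()), int(edge_index.max())
    # Valid node indices are 0 .. num_nodes - 1, i.e. max() must be < x.size(0).
    if lo < 0 or hi >= num_nodes:
        raise ValueError(
            f"edge_index uses indices in [{lo}, {hi}], but x has only "
            f"{num_nodes} rows (valid range 0..{num_nodes - 1})"
        )


x = torch.randn(4, 8)  # 4 nodes

# An edge_index whose largest index equals x.size(0), one past the last row.
bad = torch.tensor([[0, 1, 2, 3], [4, 4, 4, 4]])
try:
    check_edge_index(x, bad)
except ValueError as err:
    print(err)

good = torch.tensor([[0, 1, 2], [3, 3, 3]])
check_edge_index(x, good)  # passes silently
```

Running such a check once over every processed graph is cheap and pinpoints which transform produced the inconsistent data.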