index out of range while using to_dense_batch
See original GitHub issue

🐛 Describe the bug

When I use `torch_geometric.utils.to_dense_batch`, I get an index-out-of-range error:
```python
import torch
from torch_scatter import scatter_add


def to_dense_batch(x, batch=None, fill_value=0, max_num_nodes=None,
                   batch_size=None):
    r"""Given a sparse batch of node features
    :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}` (with
    :math:`N_i` indicating the number of nodes in graph :math:`i`), creates a
    dense node feature tensor
    :math:`\mathbf{X} \in \mathbb{R}^{B \times N_{\max} \times F}` (with
    :math:`N_{\max} = \max_i^B N_i`).
    In addition, a mask of shape :math:`\mathbf{M} \in \{ 0, 1 \}^{B \times
    N_{\max}}` is returned, holding information about the existence of
    fake nodes in the dense representation.

    Args:
        x (Tensor): Node feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. Must be ordered. (default: :obj:`None`)
        fill_value (float, optional): The value for invalid entries in the
            resulting dense output tensor. (default: :obj:`0`)
        max_num_nodes (int, optional): The size of the output node dimension.
            (default: :obj:`None`)
        batch_size (int, optional): The batch size. (default: :obj:`None`)

    :rtype: (:class:`Tensor`, :class:`BoolTensor`)
    """
    if batch is None and max_num_nodes is None:
        mask = torch.ones(1, x.size(0), dtype=torch.bool, device=x.device)
        return x.unsqueeze(0), mask

    if batch is None:
        batch = x.new_zeros(x.size(0), dtype=torch.long)

    if batch_size is None:
        batch_size = int(batch.max()) + 1

    num_nodes = scatter_add(batch.new_ones(x.size(0)), batch, dim=0)  # ,
    # dim_size=batch_size)  # num_points_in_pillars
    cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])  # cumulative sum of num_nodes

    if max_num_nodes is None:
        max_num_nodes = int(num_nodes.max())

    idx = torch.arange(batch.size(0), dtype=torch.long, device=x.device)
    idx = (idx - cum_nodes[batch]) + (batch * max_num_nodes)

    size = [batch_size * max_num_nodes] + list(x.size())[1:]
    out = x.new_full(size, fill_value)
    if idx.max() >= out.shape[0]:
        import pdb; pdb.set_trace()  # breakpoint I added to catch the failure
    out[idx] = x
    out = out.view([batch_size, max_num_nodes] + list(x.size())[1:])

    mask = torch.zeros(batch_size * max_num_nodes, dtype=torch.bool,
                       device=x.device)
    mask[idx] = True
    mask = mask.view(batch_size, max_num_nodes)

    return out, mask
```
`idx.max()` exceeds the first dimension of `out`. This happens intermittently.
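One way this overflow can occur (a hypothetical sketch, not the reporter's data) is when a user-supplied `max_num_nodes` is smaller than the largest graph in the batch: the flat index for the last graph then runs past the end of `out`, and whether it triggers depends on the batch composition, which would explain the intermittent failures:

```python
import torch

# Hypothetical batch: the last graph has more nodes (4) than the
# user-supplied max_num_nodes (3), so the flat index overflows.
batch = torch.tensor([0, 0, 1, 1, 1, 1])  # graph sizes: 2 and 4
max_num_nodes, batch_size = 3, 2

num_nodes = torch.bincount(batch, minlength=batch_size)
cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])

idx = torch.arange(batch.size(0))
idx = (idx - cum_nodes[batch]) + (batch * max_num_nodes)
# idx is [0, 1, 3, 4, 5, 6], but out has only
# batch_size * max_num_nodes = 6 rows, so out[idx] = x raises IndexError.
```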
Issue Analytics
- State:
- Created 10 months ago
- Comments: 9 (5 by maintainers)
Top GitHub Comments
Yeah, thanks for the code. I fixed it via https://github.com/pyg-team/pytorch_geometric/pull/6124.
I found that changing `idx = (idx - cum_nodes[batch]) + (batch * max_num_nodes)` to `idx = torch.clamp(idx - cum_nodes[batch], max=max_num_nodes - 1) + (batch * max_num_nodes)` in `to_dense_batch` helps when "dropping" of nodes is needed, because it drops the trailing nodes that exceed `max_num_nodes`. This is useful when `max_num_nodes` is smaller than the largest graph in the batch, and it preserves the original results when `max_num_nodes` is larger.
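The effect of the clamp can be sketched on a small hypothetical batch whose last graph (4 nodes) exceeds `max_num_nodes = 3` (toy tensors for illustration only):

```python
import torch

batch = torch.tensor([0, 0, 1, 1, 1, 1])  # graph sizes: 2 and 4
max_num_nodes, batch_size = 3, 2

num_nodes = torch.bincount(batch, minlength=batch_size)
cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])
local = torch.arange(batch.size(0)) - cum_nodes[batch]  # per-graph position

idx = local + (batch * max_num_nodes)  # [0, 1, 3, 4, 5, 6] -- overflows
idx_safe = torch.clamp(local, max=max_num_nodes - 1) + (batch * max_num_nodes)
# idx_safe is [0, 1, 3, 4, 5, 5]: positions past the cap collapse onto the
# last valid slot, so out[idx_safe] = x stays in range and the surplus
# nodes are effectively dropped by being overwritten.
```

When no graph exceeds `max_num_nodes`, the clamp is a no-op and `idx_safe` equals `idx`, which is why the original results are preserved in that case.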