
BatchNorm tensor size issue

See original GitHub issue

šŸ› Bug

I've been running a number of comparisons between reference implementations of PointNet++ and DGCNN. What I've noticed is that the PyG implementation is considerably slower than the original works. I've timed each step and found that PyG is faster in many of the steps (e.g., FPS, knn-graph, radius-graph, scatter functions), but is much slower when applying the MLP. I've found that this is due to the tensor shape that is provided to the torch.nn.BatchNorm1d function. BatchNorm1d expects a tensor of size (N, C, L) or (N, C), where N is the batch size, C is the number of channels, and L is the size (number of points) of each image/graph. By default, PyG feeds a tensor of size (N * L, C) to the BatchNorm module.
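
For reference, a minimal sketch of the two layouts (sizes and variable names are illustrative, not taken from the issue). Both calls normalize each channel over the same N * L elements, so they are mathematically equivalent; the difference is which BatchNorm kernel gets dispatched:

import torch

N, C, L = 32, 64, 1024                        # batch size, channels, points per graph (made up)
x = torch.randn(N * L, C)                     # PyG-style flattened node feature matrix
bn = torch.nn.BatchNorm1d(C)

y_flat = bn(x)                                # interpreted as (N * L, C): one large batch of rows
y_3d = bn(x.view(N, L, C).transpose(1, 2))    # reinterpreted as (N, C, L)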

I've written my own BatchNorm1d wrapper that first reshapes the (N * L, C) tensor into a (N, C, L) tensor and then feeds it to the BatchNorm module. This speeds up the computation significantly:

import torch
from torch import Tensor


class GraphBatchNorm1d(torch.nn.Module):
    r"""Applies batch normalization to a graph signal with a fixed number of
    points per graph."""
    def __init__(self, in_channels, num_points=None, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(GraphBatchNorm1d, self).__init__()
        self.num_points = num_points
        self.bn = torch.nn.BatchNorm1d(in_channels, eps, momentum, affine,
                                       track_running_stats)
        self.reset_parameters()

    def reset_parameters(self):
        self.bn.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        # Without a fixed number of points we cannot reshape, so fall back
        # to the default (N * L, C) layout.
        if self.num_points is None:
            return self.bn(x)

        # Reshape (N * L, C) -> (N, L, C) -> (N, C, L) so BatchNorm1d uses
        # the faster 3D code path, then restore the original shape.
        sh = x.size()
        x = x.view(-1, self.num_points, sh[1]).transpose(2, 1)
        x = self.bn(x)
        return x.transpose(2, 1).reshape(sh).contiguous()

    def __repr__(self):
        return f'{self.__class__.__name__}({self.bn.num_features})'
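
For context, a hypothetical usage of the wrapper (values are illustrative); num_points must match the fixed number of points per graph for the reshape to be valid:

bn = GraphBatchNorm1d(in_channels=64, num_points=1024)
x = torch.randn(8 * 1024, 64)   # 8 graphs with 1024 points each, 64 channels
out = bn(x)                     # same shape as the input: (8 * 1024, 64)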

I was wondering: have others noticed this problem as well? Are there better solutions to this?

One challenge with PointNet++ is that it works on radius graphs, which means the number of neighbors is different for each point. The consequence is that x_j and x_i cannot be reshaped into a 'rectangular' size. Other implementations solve this by truncating the neighborhoods to a maximum of k neighbors and repeating the center point for neighborhoods with fewer than k neighbors.

A sub-optimal implementation of that solution:

import torch
from torch_cluster import knn


def batched_radius(pos_x, pos_y, r, batch_x, batch_y, max_num_neighbors=32,
                   sample_idx=None):
    # `row` indexes pos_y (the query/center points), `col` indexes pos_x.
    row, col = knn(pos_x, pos_y, max_num_neighbors, batch_x, batch_y)
    radius_squared = (pos_y[row] - pos_x[col]).pow(2).sum(dim=-1)

    # Replace edges to points further than r with self-loops.
    # If pos_y is a sub-graph of pos_x, sampled with sample_idx, map indices correctly.
    self_loop = row
    if sample_idx is not None:
        idx_lookup = torch.arange(pos_x.size(0), device=pos_x.device)[sample_idx]
        self_loop = idx_lookup[self_loop]
    col = torch.where(radius_squared > (r ** 2), self_loop, col)
    return row, col

This solution can be improved by using torch_cluster's radius function (it's faster than knn) and adding self-loops to the result.
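
A rough sketch of that improvement (my own illustration, not from the issue): it assumes torch_cluster's radius function and the same pos_x / pos_y / sample_idx convention as above, and pads each neighborhood with self-loops up to max_num_neighbors so the result is rectangular.

import torch
from torch_cluster import radius


def batched_radius_v2(pos_x, pos_y, r, batch_x, batch_y, max_num_neighbors=32,
                      sample_idx=None):
    # `row` indexes pos_y (centers), `col` indexes pos_x (neighbors);
    # radius() already drops everything outside the ball of radius r.
    row, col = radius(pos_x, pos_y, r, batch_x, batch_y,
                      max_num_neighbors=max_num_neighbors)

    # Index of each center in pos_x, so a center can act as its own neighbor.
    centers = torch.arange(pos_y.size(0), device=pos_y.device)
    if sample_idx is None:
        centers_in_x = centers  # assumes pos_y == pos_x
    else:
        centers_in_x = torch.arange(pos_x.size(0), device=pos_x.device)[sample_idx]

    # Pad neighborhoods that received fewer than max_num_neighbors edges with
    # self-loops, so every center ends up with exactly max_num_neighbors edges.
    counts = torch.bincount(row, minlength=pos_y.size(0))
    missing = (max_num_neighbors - counts).clamp(min=0)
    pad_row = torch.repeat_interleave(centers, missing)
    row = torch.cat([row, pad_row])
    col = torch.cat([col, centers_in_x[pad_row]])

    # Group edges by center so the result can be reshaped into a rectangular
    # (num_centers, max_num_neighbors) layout downstream.
    perm = torch.argsort(row)
    return row[perm], col[perm]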

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 5 (4 by maintainers)

Top GitHub Comments

1 reaction
rusty1s commented, Apr 13, 2021

I personally think this is a bug in PyTorch. I'm reaching out to the PyTorch team to get some more information.

0 reactions
rusty1s commented, Apr 13, 2021

This issue can be tracked here: https://github.com/pytorch/pytorch/issues/38915

