BatchNorm tensor size issue
🐛 Bug
I've been running a number of comparisons between reference implementations of PointNet++ and DGCNN. What I've noticed is that the PyG implementation is considerably slower than the original implementations. I've timed each step and found that PyG is faster in many of the steps (e.g., FPS, knn-graph, radius-graph, scatter functions), but is much slower when applying the MLP. I've found that this is due to the shape of the tensor that is passed to the `torch.nn.BatchNorm1d` module. `BatchNorm1d` expects a tensor of size (N, C, L) or (N, C), where N is the batch size (here, the number of graphs in the batch), C is the number of channels and L is the size of each image/graph (for point clouds, the number of points). By default, PyG feeds a tensor of size (N * L, C) to the BatchNorm module.
I've written my own `BatchNorm1d` wrapper that first reshapes the (N * L, C) tensor into an (N, C, L) tensor and then feeds it to the BatchNorm module. This speeds up the computation significantly:
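```python
import torch
from torch import Tensor


class GraphBatchNorm1d(torch.nn.Module):
    r"""Applies batch normalization to a graph signal.

    If ``num_points`` is given, every graph in the batch is assumed to have
    exactly ``num_points`` nodes, stored contiguously, so the (N * L, C)
    input can be viewed as (N, L, C) and normalized in the (N, C, L) layout.
    """
    def __init__(self, in_channels, num_points=None, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super().__init__()
        self.num_points = num_points
        self.bn = torch.nn.BatchNorm1d(in_channels, eps=eps, momentum=momentum,
                                       affine=affine,
                                       track_running_stats=track_running_stats)
        self.reset_parameters()

    def reset_parameters(self):
        self.bn.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        if self.num_points is None:
            # Graphs of varying size: keep PyG's (N * L, C) layout.
            return self.bn(x)
        sh = x.size()
        # (N * L, C) -> (N, L, C) -> (N, C, L)
        x = x.view(-1, self.num_points, sh[1]).transpose(2, 1)
        x = self.bn(x)
        # (N, C, L) -> (N, L, C) -> (N * L, C)
        return x.transpose(2, 1).reshape(sh).contiguous()

    def __repr__(self):
        return f'{self.__class__.__name__}({self.bn.num_features})'
```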
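The reshape in the wrapper above should not change the result: for a 3-D (N, C, L) input, `BatchNorm1d` computes each channel's mean and variance over the N and L dimensions, which is the same set of N * L values it uses for a flattened (N * L, C) input. The sketch below checks this numerically, assuming every graph has exactly `num_points` points stored contiguously; `bn_grouped` is a hypothetical standalone version of the wrapper's `forward`, not a PyG function. If the outputs agree, any speed difference comes from which kernel runs, not from different math.

```python
import torch


def bn_grouped(bn: torch.nn.BatchNorm1d, x: torch.Tensor, num_points: int) -> torch.Tensor:
    # Hypothetical helper mirroring GraphBatchNorm1d.forward.
    n, c = x.shape
    # (N * L, C) -> (N, L, C) -> (N, C, L), the documented 3-D BatchNorm1d layout.
    out = bn(x.view(-1, num_points, c).transpose(1, 2))
    # (N, C, L) -> (N, L, C) -> (N * L, C)
    return out.transpose(1, 2).reshape(n, c)


torch.manual_seed(0)
batch_size, num_points, channels = 4, 128, 64
x = torch.randn(batch_size * num_points, channels)
bn = torch.nn.BatchNorm1d(channels)  # training mode: normalizes with batch statistics

flat = bn(x)                               # PyG default layout: (N * L, C)
grouped = bn_grouped(bn, x, num_points)    # reshaped layout: (N, C, L)

# Same per-channel statistics, so the outputs should match up to float error.
print(torch.allclose(flat, grouped, atol=1e-5))
```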
I was wondering: have others noticed this problem as well? Are there better solutions to this?
One challenge with PointNet++ is that it works on radius graphs, which means the number of neighbors is different for each point. The consequence is that `x_j` and `x_i` cannot be reshaped into a "rectangular" size. Other implementations solve this by truncating the neighborhoods to a maximum of k neighbors and repeating the center point for neighborhoods with fewer than k neighbors.
A sub-optimal implementation of that solution:
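```python
import torch
from torch_cluster import knn


def batched_radius(pos_x, pos_y, r, batch_x, batch_y, max_num_neighbors=32,
                   sample_idx=None):
    # For every point in pos_y, find its max_num_neighbors nearest points in
    # pos_x. row indexes into pos_y (centers), col indexes into pos_x.
    row, col = knn(pos_x, pos_y, max_num_neighbors, batch_x, batch_y)
    radius_squared = (pos_y[row] - pos_x[col]).pow(2).sum(dim=-1)

    # Replace edges to points further than r with self-loops (back to the
    # center). Without sample_idx, pos_y is assumed to be pos_x itself.
    self_loop = row
    if sample_idx is not None:
        # pos_y was sampled from pos_x with sample_idx (index or boolean mask):
        # map each center to its index in pos_x.
        idx_lookup = torch.arange(pos_x.size(0), device=pos_x.device)[sample_idx]
        self_loop = idx_lookup[self_loop]
    col = torch.where(radius_squared > (r ** 2), self_loop, col)
    return row, col
```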
This solution can be improved by using torch_cluster's `radius` function (it's faster than `knn`) and adding self-loops to the result.
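A minimal sketch of the "add self-loops" step, assuming the edges come from something like `torch_cluster.radius(pos_x, pos_y, r, batch_x, batch_y, max_num_neighbors=k)` with the same convention as the code above (`row` indexes centers in `pos_y`, `col` indexes neighbors in `pos_x`, at most k edges per center). `pad_neighborhoods` is a hypothetical helper, not part of PyG or torch_cluster: it fills a dense (num_centers, k) index tensor with the center itself and then writes each real neighbor into the next free slot, so gathering features gives a rectangular tensor.

```python
import torch
from torch import Tensor


def pad_neighborhoods(row: Tensor, col: Tensor, center: Tensor, k: int) -> Tensor:
    """Hypothetical helper: variable-size neighbor lists -> dense (num_centers, k).

    Assumes each center has at most k edges; center[i] is the index of
    center i in pos_x and fills the unused slots (self-loops).
    """
    num_centers = center.numel()
    # Every slot starts out as a self-loop back to its center.
    out = center.view(-1, 1).repeat(1, k)
    # Sort edges by center so each neighborhood is a contiguous run.
    row, perm = torch.sort(row, stable=True)
    col = col[perm]
    # Slot of each edge inside its neighborhood: 0, 1, 2, ... per center.
    counts = torch.bincount(row, minlength=num_centers)
    start = torch.cumsum(counts, dim=0) - counts
    slot = torch.arange(row.numel(), device=row.device) - start[row]
    out[row, slot] = col
    return out


# Toy example: two centers (pos_x indices 0 and 4), k = 3.
row = torch.tensor([1, 0, 1])
col = torch.tensor([5, 2, 6])
center = torch.tensor([0, 4])
idx = pad_neighborhoods(row, col, center, k=3)
x = torch.randn(8, 16)   # features of 8 points in pos_x
x_j = x[idx]             # rectangular: (num_centers, k, C)
print(idx.tolist())
```

Here `idx` is `[[2, 0, 0], [5, 6, 4]]`: center 0 has one real neighbor and two self-loops, center 4 has two neighbors and one self-loop. Sorting first means the result does not depend on the order in which `radius` returns its edges.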
Top GitHub Comments
I personally think this is a bug in PyTorch. I'm reaching out to the PyTorch team to get more information.
This issue can be tracked here: https://github.com/pytorch/pytorch/issues/38915