there is no need to rewrite the 'class LayerNorm(nn.Module)'
The reason for rewriting ‘class LayerNorm(nn.Module)’ is that you think the LayerNorm provided by PyTorch only supports the ‘channels_last’ format (batch_size, height, width, channels), so you wrote a new version to support the ‘channels_first’ format (batch_size, channels, height, width). However, I found that F.layer_norm and nn.LayerNorm do not require a particular order of channels, height and width, because F.layer_norm derives the dimensions to normalize over from the last dimensions, using ‘normalized_shape’ to compute the mean and variance.
Specifically, the PyTorch implementation uses every value in an image to calculate a single mean and variance, and every value in the image is normalized with these two numbers. Your implementation, by contrast, uses the values across channels at each spatial position, so every spatial position gets its own mean and variance.
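To make the two behaviours concrete, here is a small standalone sketch (the tensor shapes and eps value are arbitrary, chosen only for illustration): F.layer_norm normalizes over exactly the trailing dimensions covered by normalized_shape, so (C,) in a channels_last layout gives per-position statistics, while (C, H, W) gives per-image statistics.

```python
import torch
import torch.nn.functional as F

N, C, H, W = 2, 4, 8, 8          # arbitrary shapes, just for illustration
x = torch.randn(N, C, H, W)
eps = 1e-6

# Per-position statistics: mean/variance over the C values at each (h, w),
# which is what ConvNeXt's channels_first branch computes.
u = x.mean(1, keepdim=True)                       # (N, 1, H, W)
s = (x - u).pow(2).mean(1, keepdim=True)          # (N, 1, H, W)
per_position = (x - u) / torch.sqrt(s + eps)

# The same thing via PyTorch's layer_norm in channels_last layout,
# with normalized_shape=(C,):
ref = F.layer_norm(x.permute(0, 2, 3, 1), (C,), eps=eps).permute(0, 3, 1, 2)
print(torch.allclose(per_position, ref, atol=1e-5))   # True

# Per-image statistics: one mean/variance over all C*H*W values,
# obtained with normalized_shape=(C, H, W):
per_image = F.layer_norm(x, (C, H, W), eps=eps)
```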
When I changed the following code in convnext.py, I found it computes the same thing as ‘F.layer_norm’ or ‘nn.LayerNorm’ in PyTorch. https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119
```python
u = x.mean([1, 2, 3], keepdim=True)                # statistics over (C, H, W)
# u = x.mean(1, keepdim=True)                      # original code: statistics over C only
s = (x - u).pow(2).mean([1, 2, 3], keepdim=True)
# s = (x - u).pow(2).mean(1, keepdim=True)         # original code
x = self.weight[None, :] * x + self.bias[None, :]
# x = self.weight[:, None, None] * x + self.bias[:, None, None]  # original code
```
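For what it’s worth, the statistics over dims [1, 2, 3] can be checked against F.layer_norm with normalized_shape=(C, H, W). A minimal sketch (shapes and eps are arbitrary; the affine step is left out, since nn.LayerNorm over (C, H, W) would carry weight/bias of shape (C, H, W) rather than the (C,) parameters in this class):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8, 8)      # (N, C, H, W), arbitrary sizes
eps = 1e-6

# Per-image statistics, as in the modified lines above.
u = x.mean([1, 2, 3], keepdim=True)
s = (x - u).pow(2).mean([1, 2, 3], keepdim=True)
manual = (x - u) / torch.sqrt(s + eps)

# PyTorch's layer_norm over the full (C, H, W) trailing shape.
ref = F.layer_norm(x, x.shape[1:], eps=eps)
print(torch.allclose(manual, ref, atol=1e-5))   # True
```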
There is no need to rewrite ‘class LayerNorm(nn.Module)’; it’s just a misunderstanding about the LayerNorm implementation.
Top GitHub Comments
FYI, the LayerNorm paper’s section 6.7 talks about CNNs. Although it does not clearly say how LayerNorm is applied to an (N, C, H, W) tensor, the wording does give some hints.
My reading of it is that the “original” LayerNorm does normalize over (C, H, W) (and they think this might not be a good idea).
That said, from today’s Transformer point of view, H and W become the “sequence” dimension, and it then becomes natural to normalize only over the C dimension. By the way, “positional normalization” (https://arxiv.org/pdf/1907.04312.pdf) seems to be the first work to formally name such an operation for CNNs.
If I understand correctly, normalizing over all of the C, H, W dimensions is equivalent to a GroupNorm with num_groups=1. We haven’t had a chance to try this, though. The PoolFormer paper uses this as its default.
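A quick sanity check of that equivalence (not from the repo; shapes are arbitrary, and affine is disabled so the differing parameter shapes don’t get in the way):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(2, 4, 8, 8)      # (N, C, H, W), arbitrary sizes
eps = 1e-5

# GroupNorm with a single group pools statistics over all of (C, H, W) ...
gn = nn.GroupNorm(num_groups=1, num_channels=4, eps=eps, affine=False)
out_gn = gn(x)

# ... which matches LayerNorm with normalized_shape=(C, H, W).
out_ln = F.layer_norm(x, x.shape[1:], eps=eps)
print(torch.allclose(out_gn, out_ln, atol=1e-5))   # True
```

The learnable parameters are not interchangeable, though: GroupNorm’s weight/bias are per-channel, shape (C,), while a LayerNorm over (C, H, W) would carry per-element parameters of shape (C, H, W), so the equivalence holds for the normalization itself.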