`create_feature_extractor` vs. `_utils.IntermediateLayerGetter`
🐛 Describe the bug
When I used `resnet50` from `torchvision.models`, I replaced `BatchNorm2d` with `FrozenBatchNorm2d`. I then needed to extract the output of `layer4`, but I found that `create_feature_extractor` does not display `FrozenBatchNorm2d`, whereas `IntermediateLayerGetter` displays it properly. As shown below:
```python
import torch
from torchvision import models
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models._utils import IntermediateLayerGetter

print(torch.__version__)  # 1.10.2+cu113


class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt,
    without which any models other than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]
        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias


model = models.resnet50(pretrained=True, norm_layer=FrozenBatchNorm2d)
model1 = create_feature_extractor(model, return_nodes={'layer4': '0'})
model2 = IntermediateLayerGetter(model, return_layers={'layer4': '0'})
# model1.load_state_dict(model2.state_dict())  # error!
print(model1)
print(model2)
```
The partial output of `create_feature_extractor` is as follows:
```
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Module(
    (0): Module(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (downsample): Module(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
```
The partial output of `IntermediateLayerGetter` is as follows:
```
(layer4): Sequential(
  (0): Bottleneck(
    (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn1): FrozenBatchNorm2d()
    (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn2): FrozenBatchNorm2d()
    (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): FrozenBatchNorm2d()
    (relu): ReLU(inplace=True)
    (downsample): Sequential(
      (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
      (1): FrozenBatchNorm2d()
    )
  )
```
So, why does `create_feature_extractor` not work as expected? And is there a built-in module with the same functionality as the `FrozenBatchNorm2d` above?
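A likely explanation: `IntermediateLayerGetter` simply re-registers the original submodules in an `nn.ModuleDict`, so they print unchanged, whereas `create_feature_extractor` symbolically traces the model with `torch.fx`. During tracing, a module the tracer does not treat as a leaf (roughly, anything outside `torch.nn`) is traced *through*: its `forward` is inlined into the graph as plain tensor ops, so the custom `FrozenBatchNorm2d` no longer appears as a named submodule even though its computation is preserved. Below is a minimal sketch of a workaround, assuming a torchvision version whose tracer accepts the `leaf_modules` kwarg (added after 0.11, if I recall correctly); it also uses the built-in `torchvision.ops.FrozenBatchNorm2d`, which addresses the second question:

```python
import torch
from torchvision import models
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.ops import FrozenBatchNorm2d  # built-in frozen batch norm

model = models.resnet50(pretrained=True, norm_layer=FrozenBatchNorm2d)

# Mark FrozenBatchNorm2d as a leaf so the FX tracer keeps it as a
# submodule instead of inlining its forward() into the traced graph.
# NOTE: the 'leaf_modules' tracer kwarg is assumed to be available here;
# it is not in every torchvision release.
extractor = create_feature_extractor(
    model,
    return_nodes={"layer4": "0"},
    tracer_kwargs={"leaf_modules": [FrozenBatchNorm2d]},
)
print(extractor)  # layer4 should now print its FrozenBatchNorm2d children

out = extractor(torch.rand(1, 3, 224, 224))
print(out["0"].shape)  # expected: torch.Size([1, 2048, 7, 7])
```

`torchvision.ops.FrozenBatchNorm2d` takes the number of features as its first argument, just like `BatchNorm2d`, and already applies `eps` before `rsqrt`, so it should be a drop-in replacement for the copy-pasted class above.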
Versions
```
PyTorch version: 1.10.2+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Professional
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.8.5 (default, Sep 3 2020, 21:29:08) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19041-SP0
Is CUDA available: True
CUDA runtime version: 11.1.74
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.22.1
[pip3] numpydoc==1.1.0
[pip3] torch==1.10.2+cu113
[pip3] torch-tb-profiler==0.2.0
[pip3] torchaudio==0.10.2+cu113
[pip3] torchfile==0.1.0
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.11.3+cu113
[pip3] torchviz==0.0.2
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.1.1 heb2d755_9 conda-forge
[conda] mkl 2020.2 256
[conda] mkl-service 2.3.0 py38hb782905_0
[conda] mkl_fft 1.2.0 py38h45dec08_0
[conda] mkl_random 1.1.1 py38h47e9c7a_0
[conda] mypy-extensions 0.4.3 pypi_0 pypi
[conda] numpy 1.22.1 pypi_0 pypi
[conda] numpydoc 1.1.0 pyhd3eb1b0_1
[conda] torch 1.10.2+cu113 pypi_0 pypi
[conda] torch-tb-profiler 0.2.0 pypi_0 pypi
[conda] torchaudio 0.10.2+cu113 pypi_0 pypi
[conda] torchfile 0.1.0 pypi_0 pypi
[conda] torchsummary 1.5.1 pypi_0 pypi
[conda] torchvision 0.11.3+cu113 pypi_0 pypi
[conda] torchviz 0.0.2 pypi_0 pypi
```
Issue Analytics
- Created: 2 years ago
- Comments: 10 (8 by maintainers)
It is indeed on my radar, @datumbox. @alexander-soare, thanks for the clarification. I am still leaning towards making it mandatory; if we later decide otherwise (say, people come up with use cases for keeping it optional), we can change the behaviour in a non-BC-breaking way, for instance by introducing a default flag.

@datumbox That is a good point. I can see that this is indeed the case in PyTorch core, so we can align our approach and always make it mandatory. I don't see a good reason for skipping the default auto-wraps.
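For context on the "default auto-wraps" mentioned here: if I read the thread correctly, later torchvision releases have `create_feature_extractor` pass default `tracer_kwargs` that register the `nn.Module` classes in `torchvision.ops` (including the built-in `FrozenBatchNorm2d`) as leaf modules automatically. Under that assumption, the earlier workaround reduces to the sketch below:

```python
import torch
from torchvision import models
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.ops import FrozenBatchNorm2d

# Assumes a torchvision release where torchvision.ops modules are treated
# as leaf modules by default; no explicit tracer_kwargs are needed then.
model = models.resnet50(pretrained=True, norm_layer=FrozenBatchNorm2d)
extractor = create_feature_extractor(model, return_nodes={"layer4": "0"})
print(extractor)  # FrozenBatchNorm2d should appear as a leaf submodule
```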