
`create_feature_extractor` vs. `_utils.IntermediateLayerGetter`


🐛 Describe the bug

I used resnet50 from torchvision.models and replaced BatchNorm2d with FrozenBatchNorm2d. I then needed to extract the output of layer4. The model returned by create_feature_extractor does not display any FrozenBatchNorm2d submodules, but the one returned by IntermediateLayerGetter displays them properly, as shown below:

import torch
from torchvision import models
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models._utils import IntermediateLayerGetter

print(torch.__version__)  # 1.10.2+cu113

class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.ops.misc with added eps before rsqrt,
    without which any other models than torchvision.models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias

model = models.resnet50(pretrained=True, norm_layer=FrozenBatchNorm2d)
model1 = create_feature_extractor(model, return_nodes={'layer4': '0'})
model2 = IntermediateLayerGetter(model, return_layers={'layer4': '0'})

# model1.load_state_dict(model2.state_dict())  # error!
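A likely explanation follows. It is offered as a sketch, not a confirmed diagnosis. create_feature_extractor is built on torch.fx symbolic tracing, and FX records a submodule as a single `call_module` node only when the tracer treats it as a leaf. By default, leaves are the standard `torch.nn` modules. A user-defined class like this FrozenBatchNorm2d is traced *through*, so its arithmetic is inlined into the graph and no FrozenBatchNorm2d submodule survives in the resulting GraphModule. That is also consistent with the `load_state_dict` error, because the buffer names no longer line up. The self-contained example below uses plain torch.fx only. `FrozenBN`, `Block`, `KeepFrozenBN` and `call_module_types` are hypothetical names introduced for illustration. `KeepFrozenBN` overrides the real `torch.fx.Tracer.is_leaf_module` hook.

```python
import torch
import torch.fx as fx


class FrozenBN(torch.nn.Module):
    """Hypothetical stand-in for the FrozenBatchNorm2d above: statistics are buffers."""

    def __init__(self, n: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def forward(self, x):
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        scale = self.weight.reshape(1, -1, 1, 1) * (rv + self.eps).rsqrt()
        shift = self.bias.reshape(1, -1, 1, 1) - rm * scale
        return x * scale + shift


class Block(torch.nn.Module):
    """Hypothetical tiny model: a conv followed by the custom norm layer."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 4, kernel_size=3, padding=1, bias=False)
        self.bn = FrozenBN(4)

    def forward(self, x):
        return self.bn(self.conv(x))


class KeepFrozenBN(fx.Tracer):
    """Hypothetical tracer that records each FrozenBN as one call_module node."""

    def is_leaf_module(self, m, module_qualified_name):
        if isinstance(m, FrozenBN):
            return True
        return super().is_leaf_module(m, module_qualified_name)


def call_module_types(gm):
    """Class names of the submodules invoked via call_module nodes in a GraphModule."""
    return [type(gm.get_submodule(n.target)).__name__
            for n in gm.graph.nodes if n.op == "call_module"]


model = Block().eval()
default_gm = fx.symbolic_trace(model)                          # traces through FrozenBN
leaf_gm = fx.GraphModule(model, KeepFrozenBN().trace(model))  # keeps FrozenBN whole

print(call_module_types(default_gm))  # ['Conv2d']
print(call_module_types(leaf_gm))     # ['Conv2d', 'FrozenBN']
# Only the leaf version is guaranteed to keep the original keys (e.g. 'bn.running_var').
print(sorted(leaf_gm.state_dict()) == sorted(model.state_dict()))  # True
```

torchvision's create_feature_extractor accepts a `tracer_kwargs` dict that it passes to its own tracer. Recent releases document a `leaf_modules` entry there, so something like `create_feature_extractor(model, return_nodes={'layer4': '0'}, tracer_kwargs={'leaf_modules': [FrozenBatchNorm2d]})` may achieve the same effect. This has not been checked against 0.11.3, so confirm it in that version's documentation first.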


The partial output of create_feature_extractor is as follows:

  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Module(
    (0): Module(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (downsample): Module(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)

The partial output of IntermediateLayerGetter is as follows:

(layer4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): FrozenBatchNorm2d()
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn2): FrozenBatchNorm2d()
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): FrozenBatchNorm2d()
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): FrozenBatchNorm2d()

So why does create_feature_extractor not work as expected? Also, is there a built-in module with the same functionality as the FrozenBatchNorm2d above?
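On the second question, the folded forward in the FrozenBatchNorm2d above computes the same result as inference-mode batch norm, y = (x - running_mean) / sqrt(running_var + eps) * weight + bias. The difference is that weight and bias are registered as buffers, so they receive no gradients. An nn.BatchNorm2d in eval() mode uses the same formula, but its weight and bias remain trainable Parameters. The sketch below checks that equivalence numerically against `torch.nn.functional.batch_norm` with `training=False`, using random statistics. The helper name `frozen_bn_folded` is hypothetical. torchvision also ships `torchvision.ops.misc.FrozenBatchNorm2d`, the module the docstring refers to. Its constructor arguments have changed between releases, so check the signature in your installed version.

```python
import torch
import torch.nn.functional as F


def frozen_bn_folded(x, weight, bias, running_mean, running_var, eps=1e-5):
    """Same math as FrozenBatchNorm2d.forward: fold the statistics into a scale and a shift."""
    w = weight.reshape(1, -1, 1, 1)
    b = bias.reshape(1, -1, 1, 1)
    rv = running_var.reshape(1, -1, 1, 1)
    rm = running_mean.reshape(1, -1, 1, 1)
    scale = w * (rv + eps).rsqrt()   # weight / sqrt(running_var + eps)
    shift = b - rm * scale           # bias - running_mean * scale
    return x * scale + shift


torch.manual_seed(0)
n = 8
x = torch.randn(2, n, 5, 5)
weight, bias = torch.randn(n), torch.randn(n)
running_mean = torch.randn(n)
running_var = torch.rand(n) + 0.5  # variances must be positive

folded = frozen_bn_folded(x, weight, bias, running_mean, running_var)
# Inference-mode batch norm: uses the stored statistics and does not update them.
reference = F.batch_norm(x, running_mean, running_var, weight, bias,
                         training=False, eps=1e-5)
print(torch.allclose(folded, reference, atol=1e-4))  # tolerance absorbs float32 rounding
```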


PyTorch version: 1.10.2+cu113
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 10 Pro
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.8.5 (default, Sep 3 2020, 21:29:08) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.19041-SP0
Is CUDA available: True
CUDA runtime version: 11.1.74
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.22.1
[pip3] numpydoc==1.1.0
[pip3] torch==1.10.2+cu113
[pip3] torch-tb-profiler==0.2.0
[pip3] torchaudio==0.10.2+cu113
[pip3] torchfile==0.1.0
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.11.3+cu113
[pip3] torchviz==0.0.2
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.1.1 heb2d755_9 conda-forge
[conda] mkl 2020.2 256
[conda] mkl-service 2.3.0 py38hb782905_0
[conda] mkl_fft 1.2.0 py38h45dec08_0
[conda] mkl_random 1.1.1 py38h47e9c7a_0
[conda] mypy-extensions 0.4.3 pypi_0 pypi
[conda] numpy 1.22.1 pypi_0 pypi
[conda] numpydoc 1.1.0 pyhd3eb1b0_1
[conda] torch 1.10.2+cu113 pypi_0 pypi
[conda] torch-tb-profiler 0.2.0 pypi_0 pypi
[conda] torchaudio 0.10.2+cu113 pypi_0 pypi
[conda] torchfile 0.1.0 pypi_0 pypi
[conda] torchsummary 1.5.1 pypi_0 pypi
[conda] torchvision 0.11.3+cu113 pypi_0 pypi
[conda] torchviz 0.0.2 pypi_0 pypi

Issue Analytics

  • State: open
  • Created 2 years ago
  • Comments:10 (8 by maintainers)

Top GitHub Comments

jdsgomes commented, Mar 3, 2022

It is indeed on my radar, @datumbox. @alexander-soare, thanks for the clarification. I am still leaning towards making it mandatory. If we later decide otherwise (for instance, if people come up with use cases for skipping it), we can change the behaviour in a non-BC-breaking way, for example by introducing a flag with a default value.

jdsgomes commented, Mar 2, 2022

@datumbox That is a good point. I can see that this is indeed the case in PyTorch core, so we can align our approach and always make it mandatory. I don't see a good reason for skipping the default auto-wraps.


Top Results From Across the Web

create_feature_extractor | The Search Engine You Control
Creates a new graph module that returns intermediate nodes from a given model as dictionary with user specified keys as strings, and the...

IntermediateLayerGetter parameters - PyTorch Forums
hi, in the following example, i expect that the number of parameters of model new_m to be the same as m but it...
