
Error initializing fasterrcnn_resnet50_fpn with num_classes and pretrained


🐛 Bug

torchvision.models.detection.fasterrcnn_resnet50_fpn has two parameters, num_classes and pretrained.

If I pass just one of them, I can initialize the model:

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(num_classes=6)

or

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained = True)

The pretrained option alone won't work for me: it gives the 91-class COCO head, while I have only 5 classes (6 including background).

However, when I pass both, I get this error:

code:

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained = True, num_classes=6)

error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/var/folders/5q/5dh5_83x359fqcb949hhtj8m0000gp/T/ipykernel_11119/3787706091.py in <module>
----> 1 model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained = True, num_classes=6)
      2 # num_classes = 6 # should be initialized as target_col.nunique + 1
      3 # in_features = model.roi_heads.box_predictor.cls_score.in_features
      4 # model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    370         state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'],
    371                                               progress=progress)
--> 372         model.load_state_dict(state_dict)
    373         overwrite_eps(model, 0.0)
    374     return model

   1404 
   1405         if len(error_msgs) > 0:
-> 1406             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1407                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1408         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for FasterRCNN:
	size mismatch for roi_heads.box_predictor.cls_score.weight: copying a param with shape torch.Size([91, 1024]) from checkpoint, the shape in current model is torch.Size([6, 1024]).
	size mismatch for roi_heads.box_predictor.cls_score.bias: copying a param with shape torch.Size([91]) from checkpoint, the shape in current model is torch.Size([6]).
	size mismatch for roi_heads.box_predictor.bbox_pred.weight: copying a param with shape torch.Size([364, 1024]) from checkpoint, the shape in current model is torch.Size([24, 1024]).
	size mismatch for roi_heads.box_predictor.bbox_pred.bias: copying a param with shape torch.Size([364]) from checkpoint, the shape in current model is torch.Size([24]).
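The four mismatches follow from the shape of the box predictor: cls_score produces one score per class and bbox_pred produces four box-regression values per class. The COCO checkpoint therefore has 91 and 91 * 4 = 364 outputs, while num_classes=6 asks for 6 and 6 * 4 = 24, all fed by the same 1024-dim feature. The pure-Python sketch below reproduces that arithmetic; box_predictor_shapes is a hypothetical helper for illustration, not a torchvision function.

```python
# Feature size feeding the box predictor, as reported in the error (torch.Size([..., 1024])).
IN_FEATURES = 1024


def box_predictor_shapes(num_classes, in_features=IN_FEATURES):
    """Hypothetical helper: expected parameter shapes of the Faster R-CNN box predictor."""
    return {
        # One classification score per class (background included).
        "cls_score.weight": (num_classes, in_features),
        "cls_score.bias": (num_classes,),
        # Four box-regression coordinates per class.
        "bbox_pred.weight": (num_classes * 4, in_features),
        "bbox_pred.bias": (num_classes * 4,),
    }


checkpoint = box_predictor_shapes(91)  # COCO checkpoint loaded by pretrained=True
requested = box_predictor_shapes(6)    # model built with num_classes=6

for name, ckpt_shape in checkpoint.items():
    if ckpt_shape != requested[name]:
        print(f"size mismatch for roi_heads.box_predictor.{name}: "
              f"checkpoint {ckpt_shape}, model {requested[name]}")
```

Every parameter of the head differs, which is why all four lines appear in the RuntimeError while the backbone and RPN weights load without complaint.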

To Reproduce

Steps to reproduce the behavior:

import torchvision

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained = True, num_classes=6)

Expected behavior

It should not throw an error.

Environment

  • PyTorch / torchvision Version (e.g., 1.0 / 0.4.0):
torch                             1.9.0
torchvision                       0.10.0
  • OS (e.g., Linux):
MacOS
  • How you installed PyTorch / torchvision (conda, pip, source):
pip
  • Build command you used (if compiling from source):
None
  • Python version:
Python 3.8.5
  • CUDA/cuDNN version:
  • GPU models and configuration:
  • Any other relevant information:

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 6 (5 by maintainers)

Top GitHub Comments

1 reaction
NicolasHug commented, Aug 10, 2021

I think that the pretrained model is pretrained for a specific number of classes, which is why we cannot set both. But it would be nice to have a better error message.

0 reactions
NicolasHug commented, Aug 16, 2021

I agree it might be tricky to fully validate the kwargs, but for the non-kwargs arguments it should be fairly straightforward, even if a bit verbose.
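An explicit check on the non-kwargs arguments could look like the sketch below. validate_pretrained_num_classes is a hypothetical helper, not part of torchvision 0.10; it assumes the COCO checkpoint's 91 classes and raises a readable ValueError before any weights are downloaded or loaded, instead of the size-mismatch RuntimeError reported above.

```python
# Number of classes in the COCO checkpoint (91, per the size-mismatch error).
COCO_NUM_CLASSES = 91


def validate_pretrained_num_classes(pretrained, num_classes):
    """Hypothetical check: fail early when pretrained=True and num_classes disagree."""
    if pretrained and num_classes != COCO_NUM_CLASSES:
        raise ValueError(
            f"pretrained=True loads a COCO checkpoint with {COCO_NUM_CLASSES} classes, "
            f"but num_classes={num_classes} was requested. Load the pretrained model "
            "with the default num_classes and replace roi_heads.box_predictor instead."
        )


# Called at the top of the model builder, before the state_dict is loaded:
validate_pretrained_num_classes(pretrained=False, num_classes=6)   # OK: no checkpoint
validate_pretrained_num_classes(pretrained=True, num_classes=91)   # OK: shapes match
```

Calling it with pretrained=True and num_classes=6 raises the ValueError, which points the user at the fix rather than at tensor shapes.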
