Add info on retinanet finetune to docs.

In the object detection tutorial, there is a helpful section on which objects to replace in order to change the number of classes. Can we confirm the corresponding objects for RetinaNet? Then I can submit a PR.

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# load a model pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# replace the classifier with a new one, that has
# num_classes which is user-defined
num_classes = 2  # 1 class (person) + background
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

I think it should go just before the commented section on changing the backbone here: https://pytorch.org/vision/stable/_modules/torchvision/models/detection/retinanet.html.

It's something like this:

import torchvision
from torchvision.models.detection.retinanet import RetinaNet
from torchvision.models.detection.retinanet import AnchorGenerator

def load_backbone():
    # load the COCO-pretrained RetinaNet; only its .backbone is reused below
    backbone = torchvision.models.detection.retinanet_resnet50_fpn(
        pretrained=True)

    return backbone

def create_anchor_generator():
    # 7 sizes x 3 aspect ratios = 21 anchors per spatial location.
    # sizes (Tuple[Tuple[int]]) and aspect_ratios (Tuple[Tuple[float]])
    # take one inner tuple per feature map, because each feature map
    # could potentially have different sizes and aspect ratios.
    # Documented at https://github.com/pytorch/vision/blob/67b25288ca202d027e8b06e17111f1bcebd2046c/torchvision/models/detection/anchor_utils.py#L9
    anchor_generator = AnchorGenerator(
        sizes=((8, 16, 32, 64, 128, 256, 400),),
        aspect_ratios=((0.5, 1.0, 2.0),))

    return anchor_generator

def create_model(num_classes):
    backbone = load_backbone()
    anchor_generator = create_anchor_generator()

    model = RetinaNet(backbone.backbone, num_classes=num_classes,
                      anchor_generator=anchor_generator)

    return model

Running a forward pass through this model then fails with:

File "/Users/benweinstein/opt/miniconda3/envs/DeepForest_pytorch/lib/python3.8/site-packages/torchvision/models/detection/anchor_utils.py", line 103, in grid_anchors
  assert len(grid_sizes) == len(strides) == len(cell_anchors)

@hgaiser let me know if I missed something. Is it required to replace the anchor_generator? If you replace the anchor generator, does that mean you can’t use the same backbone (any internal size normalization done by RetinaNet may be distorted)? I’m almost done with my pytorch-lightning implementation and want to leverage the pretrained model, but allow users to specify num_classes. Removing the anchor generator seems to work, but I’d like to submit tests with my PR, and I’m not wholly sure how to write a test for this.

I would do:

    import torch

    retinanet_model = create_model(num_classes=2)

    retinanet_model.eval()
    x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
    predictions = retinanet_model(x)
    for prediction in predictions:
        # check every label; a bare non-empty list is always truthy
        assert all(label == 1 for label in prediction["labels"].tolist())

but I get blank predictions. All thoughts are welcome on how to write an appropriate test.

predictions
[{'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward>), 'scores': tensor([], grad_fn=<CatBackward>), 'labels': tensor([], dtype=torch.int64)}, {'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward>), 'scores': tensor([], grad_fn=<CatBackward>), 'labels': tensor([], dtype=torch.int64)}]

Issue Analytics

  • State: closed
  • Created 3 years ago
  • Comments: 10 (8 by maintainers)

Top GitHub Comments

hgaiser commented, Feb 23, 2021

Yeah, unfortunately you can’t load only a partial state dict. There is a PR for this, but it doesn’t seem like it will be merged. I posted there how I resolved my issue for now, which should also help with yours.

https://github.com/pytorch/pytorch/pull/39144#issuecomment-784560497

bw4sz commented, Feb 23, 2021

Ugh. I just noticed that it is documented, directly below the link I pasted: the model builder (retinanet_resnet50_fpn) takes a num_classes argument. My apologies.

    if pretrained:
        # no need to download the backbone if pretrained is set
        pretrained_backbone = False
    # skip P2 because it generates too many anchors (according to their paper)
    backbone = resnet_fpn_backbone('resnet50', pretrained_backbone,
                                   returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256))
    model = RetinaNet(backbone, num_classes, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model