Problems training Faster-RCNN from pretrained backbone

Is there any recommendation for training Faster-RCNN starting from only the pretrained backbone? I'm using the VOC 2007 dataset, and I'm able to do transfer learning starting from:

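    import torchvision

    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)  # COCO weights; newer torchvision: weights="DEFAULT"
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the 91-class COCO predictor with one for 20 VOC classes + background
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes=21)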
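The num_classes=21 used here is the 20 PASCAL VOC object classes plus one background class, since torchvision's detection models reserve label 0 for background. A minimal sketch of the corresponding label map, assuming the conventional alphabetical VOC class order (`label_for` is a hypothetical helper):

```python
# The 20 PASCAL VOC object classes in their conventional (alphabetical) order.
VOC_CLASSES = (
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
)

# Label 0 is reserved for background, so object labels start at 1.
CLASS_TO_LABEL = {name: i + 1 for i, name in enumerate(VOC_CLASSES)}

# Value passed as num_classes to FastRCNNPredictor / FasterRCNN.
NUM_CLASSES = len(VOC_CLASSES) + 1


def label_for(name: str) -> int:
    """Hypothetical helper: map a VOC class name to its training label."""
    return CLASS_TO_LABEL[name]


if __name__ == "__main__":
    print(NUM_CLASSES)             # 21
    print(label_for("aeroplane"))  # 1
    print(label_for("tvmonitor"))  # 20
```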

Starting from the COCO-pretrained 'fasterrcnn_resnet50_fpn', I'm able to obtain an mAP of 79% on the VOC 2007 test set. Problems arise when I try to train from scratch using only the pretrained backbone:
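The mAP figures in this thread are mean average precision over the 20 VOC classes. For VOC 2007 the usual protocol is 11-point interpolated AP: precision is sampled at recall thresholds 0, 0.1, ..., 1.0, taking at each threshold the highest precision reached at any recall at least that large. A minimal sketch, assuming each class's precision/recall curve has already been computed from its ranked detections:

```python
from typing import Sequence


def voc07_ap(recall: Sequence[float], precision: Sequence[float]) -> float:
    """11-point interpolated average precision, as used for PASCAL VOC 2007."""
    if len(recall) != len(precision):
        raise ValueError("recall and precision must have the same length")
    total = 0.0
    for i in range(11):
        threshold = i / 10.0  # recall thresholds 0.0, 0.1, ..., 1.0
        # highest precision among points whose recall reaches the threshold
        # (0 if the curve never gets that far)
        total += max((p for r, p in zip(recall, precision) if r >= threshold), default=0.0)
    return total / 11.0


def mean_ap(per_class_ap: Sequence[float]) -> float:
    """mAP: unweighted mean of the per-class APs (20 classes on VOC)."""
    return sum(per_class_ap) / len(per_class_ap)
```

A detector whose curve stops at 50% recall can score at most 6/11 AP for that class, which is why missed objects hurt mAP as much as false positives.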

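    import torchvision
    # pretrained=False: FPN, RPN and ROI heads start from random weights, while the
    # ResNet-50 backbone is still ImageNet-pretrained (pretrained_backbone=True by default)
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes=21)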

I have been trying to train this model for weeks, but the highest mAP I got was 63% (again on the test set).

I know that training from scratch is harder, but I would really like to know how to set the training parameters to reach a decent accuracy. In the future I may want to change the backbone, and chances are I won't find a pretrained Faster-RCNN to do transfer learning from.

Issue Analytics

  • State: closed
  • Created 4 years ago
  • Reactions: 1
  • Comments: 44 (16 by maintainers)

Top GitHub Comments

5 reactions
lpuglia commented, Aug 2, 2019

@hktxt FYI, I can easily get 72% mAP with mobilenet_v2 as the backbone, using the example provided in the FasterRCNN source code:

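    import torchvision
    from torchvision.models.detection.rpn import AnchorGenerator

    # MobileNetV2 feature extractor: a single feature map with 1280 channels
    backbone = torchvision.models.mobilenet_v2(pretrained=True).features
    backbone.out_channels = 1280
    # 5 sizes x 3 aspect ratios = 15 anchors per feature-map location
    anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                       aspect_ratios=((0.5, 1.0, 2.0),))
    # a plain (non-FPN) backbone's output is stored under the key '0'
    # (early torchvision releases used the integer 0 instead)
    roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
                                                    output_size=7,
                                                    sampling_ratio=2)
    model = torchvision.models.detection.faster_rcnn.FasterRCNN(backbone,
                       num_classes=21,
                       rpn_anchor_generator=anchor_generator,
                       box_roi_pool=roi_pooler)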

There is no need to modify the BoxHead (FasterRCNN's default TwoMLPHead is used).
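The AnchorGenerator above yields 5 sizes × 3 aspect ratios = 15 anchors per feature-map location, and FasterRCNN sizes its RPN head from that count. The plain-Python sketch below mirrors the arithmetic torchvision's AnchorGenerator uses for the zero-centred base anchors (aspect ratio is height/width, height scaled by sqrt(ratio), width by 1/sqrt(ratio), coordinates rounded); it illustrates the math and is not the library code:

```python
import math
from typing import List, Tuple

Box = Tuple[float, float, float, float]


def base_anchors(sizes: Tuple[int, ...], aspect_ratios: Tuple[float, ...]) -> List[Box]:
    """Zero-centred (x1, y1, x2, y2) anchors for one feature-map location.

    Anchors are ordered ratio-major: all sizes for the first ratio, then the next.
    Each anchor keeps an area of roughly size**2.
    """
    anchors: List[Box] = []
    for ratio in aspect_ratios:
        h_ratio = math.sqrt(ratio)
        w_ratio = 1.0 / h_ratio
        for size in sizes:
            w = w_ratio * size
            h = h_ratio * size
            anchors.append(tuple(float(round(v / 2)) for v in (-w, -h, w, h)))
    return anchors


if __name__ == "__main__":
    for box in base_anchors((32, 64, 128, 256, 512), (0.5, 1.0, 2.0)):
        print(box)
```

With a single 1280-channel map at stride 32, these 15 shapes are tiled at every location, so small objects depend entirely on the 32-pixel anchors.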

5 reactions
lpuglia commented, Jul 26, 2019

@fmassa I found out what my main problem was: I was using the val set only for validation, whereas to get good results on PASCAL VOC 2007 you are supposed to train on train and val together (trainval). Also, thanks to @hktxt's comment, I got 66% mAP training from scratch (just 3% less than expected). If anyone is interested, here are the highlights:

Backbone

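    import torch.nn as nn
    import torchvision
    from torchvision.models.detection.rpn import AnchorGenerator

    vgg = torchvision.models.vgg16(pretrained=True)
    # conv layers only, without the final max-pool: 512 channels at stride 16
    backbone = vgg.features[:-1]
    # freeze the first two conv blocks (modules 0-9: 4 convs, 4 ReLUs, 2 pools)
    for layer in backbone[:10]:
        for p in layer.parameters():
            p.requires_grad = False
    backbone.out_channels = 512
    anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                       aspect_ratios=((0.5, 1.0, 2.0),))
    # '0' is the key of a single-tensor backbone output (early torchvision used 0)
    roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=['0'],
                                                    output_size=7,
                                                    sampling_ratio=2)

    class BoxHead(nn.Module):
        """Reuses VGG16's pretrained fc6/fc7 layers as the box head (4096-d output)."""

        def __init__(self, vgg):
            super().__init__()
            # every classifier layer except the final 1000-way ImageNet Linear
            self.classifier = nn.Sequential(*list(vgg.classifier.children())[:-1])

        def forward(self, x):
            # (N, 512, 7, 7) pooled RoIs -> (N, 25088), the input size fc6 expects
            x = x.flatten(start_dim=1)
            return self.classifier(x)

    box_head = BoxHead(vgg)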
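The backbone above depends on a few shape facts about torchvision's VGG16: `vgg.features` holds 13 conv+ReLU pairs and 5 max-pools (31 modules), so `features[:-1]` drops only the last pool and leaves a 512-channel map at stride 16; `backbone[:10]` spans the first two conv blocks; and a 7×7 RoIAlign output flattens to 512·7·7 = 25088 values, the input width of VGG16's first classifier layer. A plain-Python sketch that derives these numbers from the layer configuration (the `VGG16_CFG` list is the standard configuration "D", written out by hand rather than imported):

```python
from typing import Dict, List, Tuple, Union

# VGG16 ("configuration D"): ints are 3x3 conv output channels, "M" is a 2x2 max-pool.
VGG16_CFG: List[Union[int, str]] = [
    64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
    512, 512, 512, "M", 512, 512, 512, "M",
]


def expand_features(cfg: List[Union[int, str]]) -> List[Tuple[str, int]]:
    """Lay out (kind, channels) modules like vgg.features: each conv is
    followed by a ReLU; pools keep the previous channel count."""
    layers: List[Tuple[str, int]] = []
    channels = 3
    for v in cfg:
        if v == "M":
            layers.append(("pool", channels))
        else:
            channels = int(v)
            layers.append(("conv", channels))
            layers.append(("relu", channels))
    return layers


def summarize(layers: List[Tuple[str, int]], frozen: int, roi_size: int = 7) -> Dict[str, int]:
    """Stride, output channels, frozen conv count and box-head input size."""
    out_channels = layers[-1][1]
    return {
        "num_modules": len(layers),
        "stride": 2 ** sum(1 for kind, _ in layers if kind == "pool"),
        "out_channels": out_channels,
        "frozen_convs": sum(1 for kind, _ in layers[:frozen] if kind == "conv"),
        "box_head_in_features": out_channels * roi_size * roi_size,
    }


features = expand_features(VGG16_CFG)
backbone_layers = features[:-1]  # same slice as vgg.features[:-1]: drop the last max-pool
info = summarize(backbone_layers, frozen=10)
# info == {"num_modules": 30, "stride": 16, "out_channels": 512,
#          "frozen_convs": 4, "box_head_in_features": 25088}
```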

Model

    model = torchvision.models.detection.faster_rcnn.FasterRCNN(
        backbone,  # num_classes stays unset because box_predictor is given
        rpn_anchor_generator=anchor_generator,
        box_roi_pool=roi_pooler,
        box_head=box_head,
        box_predictor=torchvision.models.detection.faster_rcnn.FastRCNNPredictor(4096, num_classes=21))
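Because box_predictor is passed explicitly, num_classes is left out: FasterRCNN rejects the combination of both. FastRCNNPredictor(4096, num_classes=21) is a pair of linear layers on the 4096-d box-head output, cls_score (21 class scores) and bbox_pred (21 × 4 = 84 per-class box offsets). The shape and parameter arithmetic, as a small plain-Python sketch:

```python
from typing import Dict


def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


def fast_rcnn_predictor_shapes(in_channels: int, num_classes: int) -> Dict[str, int]:
    """Output widths and parameter count of a cls_score + bbox_pred layer pair."""
    cls_out = num_classes       # one score per class, background included
    bbox_out = num_classes * 4  # one 4-value box regression per class
    return {
        "cls_score_out": cls_out,
        "bbox_pred_out": bbox_out,
        "params": linear_params(in_channels, cls_out) + linear_params(in_channels, bbox_out),
    }


print(fast_rcnn_predictor_shapes(4096, 21))
# {'cls_score_out': 21, 'bbox_pred_out': 84, 'params': 430185}
```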

Dataset

    # appears to be a custom wrapper: torchvision's VOCDetection takes root=, not img_folder=
    dataset = VOCDetection(img_folder=root, year='2007', image_set='trainval', transforms=transforms)
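The image_set='trainval' argument is the fix described at the start of this comment: on VOC 2007 the trainval list is the union of the train and val image IDs, and the devkit ships it as ImageSets/Main/trainval.txt. If you assemble the list yourself, a minimal sketch, assuming the standard VOC2007 directory layout (`parse_split`, `read_split` and `merge_splits` are hypothetical helpers):

```python
from pathlib import Path
from typing import Iterable, List


def parse_split(text: str) -> List[str]:
    """Image IDs from a VOC split file: first token of each non-empty line."""
    return [line.split()[0] for line in text.splitlines() if line.strip()]


def read_split(voc2007_dir: Path, split: str) -> List[str]:
    """Read VOC2007/ImageSets/Main/<split>.txt, e.g. split='train' or 'val'."""
    return parse_split((voc2007_dir / "ImageSets" / "Main" / f"{split}.txt").read_text())


def merge_splits(*splits: Iterable[str]) -> List[str]:
    """Concatenate split lists, keeping first-seen order and dropping duplicates."""
    return list(dict.fromkeys(image_id for split in splits for image_id in split))


# Usage (the path is an assumption):
# root = Path("VOCdevkit/VOC2007")
# trainval_ids = merge_splits(read_split(root, "train"), read_split(root, "val"))
```

Taking only the first token lets the same parser read the per-class split files too, whose lines carry a trailing 1/-1 flag.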

The only augmentation I used was RandomHorizontalFlip.
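For detection, a horizontal flip has to mirror the boxes as well as the pixels: in an image of width W, a box (x1, y1, x2, y2) becomes (W - x2, y1, W - x1, y2), the same convention torchvision's detection reference transforms use. A minimal plain-Python sketch of the box side only (flipping the image itself is left to your image library; the `p` and `rng` parameters are illustrative assumptions):

```python
import random
from typing import List, Optional, Sequence, Tuple

Box = List[float]


def hflip_boxes(boxes: Sequence[Sequence[float]], width: float) -> List[Box]:
    """Mirror (x1, y1, x2, y2) boxes horizontally inside an image of the given width."""
    # x1 and x2 swap roles so that x1 < x2 still holds after the flip
    return [[width - x2, y1, width - x1, y2] for x1, y1, x2, y2 in boxes]


def random_hflip_boxes(
    boxes: Sequence[Sequence[float]],
    width: float,
    p: float = 0.5,
    rng: Optional[random.Random] = None,
) -> Tuple[bool, List[Box]]:
    """Flip with probability p; also report whether the caller must flip the image."""
    rng = rng or random.Random()
    if rng.random() < p:
        return True, hflip_boxes(boxes, width)
    return False, [list(b) for b in boxes]
```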

Parameters

--epochs 40
--lr-steps 30
--momentum 0.9
--lr-gamma 0.1
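--lr-steps 30 with --lr-gamma 0.1 describes a step schedule: the learning rate is multiplied by 0.1 once epoch 30 is reached, so the last 10 of the 40 epochs run at one tenth of the starting rate. That matches PyTorch's MultiStepLR(milestones=[30], gamma=0.1) stepped once per epoch, assuming these flags feed such a scheduler as in torchvision's detection reference script. The comment does not give the initial learning rate, so BASE_LR below is a placeholder:

```python
from bisect import bisect_right
from typing import List, Sequence


def multistep_lr(base_lr: float, epoch: int, milestones: Sequence[int], gamma: float) -> float:
    """LR for a 0-indexed epoch: multiply by gamma once per milestone already reached."""
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


def schedule(base_lr: float, epochs: int, milestones: Sequence[int], gamma: float) -> List[float]:
    """Per-epoch learning rates for the whole run."""
    return [multistep_lr(base_lr, e, milestones, gamma) for e in range(epochs)]


BASE_LR = 0.01  # placeholder: the initial learning rate is not stated in the comment
lrs = schedule(BASE_LR, epochs=40, milestones=[30], gamma=0.1)
```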