
Customize FasterRCNN

Hi,

I’ve been trying, unsuccessfully, to customize the FasterRCNN implementation provided by torchvision. For example, I would like to write a custom postprocess_detections function that returns the confidence for every label, not only for the label with the highest confidence.
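For reference, torchvision's `postprocess_detections` starts by turning `class_logits` into probabilities with `F.softmax(class_logits, -1)`, so a score for every label already exists for each box. The default code then splits each box into one candidate per class and filters those candidates by `score_thresh`, NMS and `detections_per_img`. The plain-Python sketch below, with made-up logits and a hypothetical `softmax` helper, shows the per-box score vector a customized version could return alongside the best non-background label (index 0 is the background class in torchvision's detection models).

```python
import math
from typing import List, Tuple


def softmax(row: List[float]) -> List[float]:
    # Subtract the row maximum before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits: two proposals, three classes (index 0 = background).
class_logits = [
    [0.1, 2.0, 0.5],
    [1.5, 0.2, 0.3],
]

results: List[Tuple[int, List[float]]] = []
for logits in class_logits:
    scores = softmax(logits)  # confidence for every label, background included
    # Best label among the non-background classes (indices 1..C-1).
    best = max(range(1, len(scores)), key=lambda i: scores[i])
    results.append((best, scores))
    print(best, [round(s, 3) for s in scores])
```

In a real override the same idea applies to tensors: keep the full softmax row for each surviving box instead of only the single per-class score.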

In the past I managed to override the loss function with something like this:

import torchvision
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
# custom_loss: user-defined function with the same signature as fastrcnn_loss
torchvision.models.detection.roi_heads.fastrcnn_loss = custom_loss
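Patching after the model is built works for the loss because, in the torchvision versions discussed here, `RoIHeads.forward` calls `fastrcnn_loss` by its module-level name, which Python resolves in the `roi_heads` module's globals on every call. The sketch below mimics this with a throwaway module created via `types.ModuleType`; `fake_roi_heads`, `Heads` and `custom_loss` are illustrative stand-ins, not torchvision code.

```python
import textwrap
import types

# Throwaway module standing in for torchvision.models.detection.roi_heads;
# fastrcnn_loss and Heads here are illustrative stand-ins, not torchvision code.
fake_roi_heads = types.ModuleType("fake_roi_heads")
exec(
    textwrap.dedent(
        """
        def fastrcnn_loss(logits):
            return "original loss"

        class Heads:
            def forward(self, logits):
                # fastrcnn_loss is looked up in this module's globals on every call
                return fastrcnn_loss(logits)
        """
    ),
    fake_roi_heads.__dict__,
)

# Build the "model" first, as in the snippet above.
heads = fake_roi_heads.Heads()
before = heads.forward(None)


def custom_loss(logits):
    return "custom loss"


# Rebind the module attribute; the existing instance sees it on its next call.
fake_roi_heads.fastrcnn_loss = custom_loss
after = heads.forward(None)

print(before)  # original loss
print(after)  # custom loss
```

The same mechanism does not help with a method such as `postprocess_detections`, which is looked up on the object's class rather than in module globals.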

But postprocess_detections is a method of the RoIHeads class. If I replace the RoIHeads class before creating my model, I get this error:

torchvision.models.detection.roi_heads.RoIHeads = RoIHeadsCustom
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(
    pretrained=True
)
Traceback (most recent call last):
  File "test2.py", line 80, in <module>
    model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(
  File "/home/paul/.local/lib/python3.8/site-packages/torchvision/models/detection/faster_rcnn.py", line 470, in fasterrcnn_mobilenet_v3_large_fpn
    return _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=pretrained, progress=progress,
  File "/home/paul/.local/lib/python3.8/site-packages/torchvision/models/detection/faster_rcnn.py", line 393, in _fasterrcnn_mobilenet_v3_large_fpn
    model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
  File "/home/paul/.local/lib/python3.8/site-packages/torchvision/models/detection/faster_rcnn.py", line 222, in __init__
    roi_heads = RoIHeads(
  File "/home/paul/.local/lib/python3.8/site-packages/torchvision/models/detection/roi_heads.py", line 512, in __init__
    super(RoIHeads, self).__init__()
TypeError: super(type, obj): obj must be an instance or subtype of type
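The error message points at the explicit two-argument `super(RoIHeads, self)` call. Python looks up `RoIHeads` in the `roi_heads` module's globals when `__init__` runs, so after the patch it resolves to `RoIHeadsCustom`, while the object being initialised is evidently an instance of the original class (otherwise `super()` would accept it), presumably because `faster_rcnn.py` still holds its own reference to that class. The plain-Python reproduction below uses hypothetical stand-in modules and classes.

```python
import textwrap
import types

# Hypothetical stand-in for roi_heads.py, using the explicit two-argument super().
fake_roi_heads = types.ModuleType("fake_roi_heads")
exec(
    textwrap.dedent(
        """
        class Base:
            def __init__(self):
                pass

        class RoIHeads(Base):
            def __init__(self):
                # "RoIHeads" is read from this module's globals when __init__ runs
                super(RoIHeads, self).__init__()
        """
    ),
    fake_roi_heads.__dict__,
)

# The constructing module keeps its own reference to the original class,
# as "from .roi_heads import RoIHeads" would (assumed to mirror faster_rcnn.py).
OriginalRoIHeads = fake_roi_heads.RoIHeads


class RoIHeadsCustom(OriginalRoIHeads):
    pass


# The patch from the failing snippet: rebind the name in the defining module.
fake_roi_heads.RoIHeads = RoIHeadsCustom

try:
    OriginalRoIHeads()  # the original class is still what gets constructed
    error = None
except TypeError as exc:
    error = str(exc)

print(error)
```

Zero-argument `super()` does not re-read the module global, so whether this exact error appears can depend on how the installed torchvision version writes that call.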

But if I do the replacement afterwards, the model has already been created and the custom class is ignored:

model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(
    pretrained=True
)
torchvision.models.detection.roi_heads.RoIHeads = RoIHeadsCustom
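Rebinding the module attribute after construction cannot affect `model.roi_heads`, which already refers to an instance of the original class. A general Python technique for changing the behaviour of an object that already exists (not one proposed in this thread) is to assign a subclass to its `__class__`: the instance keeps all of its attributes, which for a model would include its pretrained layers. The sketch uses hypothetical `Heads`/`HeadsCustom` classes and assumes the subclass only overrides methods and adds no `__slots__` or extra state of its own.

```python
class Heads:
    def __init__(self, score_thresh):
        # stands in for state such as pretrained layers
        self.score_thresh = score_thresh

    def postprocess_detections(self):
        return "default postprocess"


class HeadsCustom(Heads):
    def postprocess_detections(self):
        # reuse the original behaviour and extend it
        return "custom + " + super().postprocess_detections()


heads = Heads(score_thresh=0.05)  # already built, like model.roi_heads
before = heads.postprocess_detections()

# Swap the class of the existing instance; its attributes stay untouched.
heads.__class__ = HeadsCustom
after = heads.postprocess_detections()

print(before)  # default postprocess
print(after)  # custom + default postprocess
print(heads.score_thresh)  # 0.05
```

Applied to the issue, the equivalent line would be `model.roi_heads.__class__ = RoIHeadsCustom`, with `RoIHeadsCustom` subclassing `RoIHeads`; this is a sketch of the idea rather than a documented torchvision customization hook.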

If anyone has ideas on how to easily customize torchvision models, that would be a great help. The only solution I can see is forking torchvision, which I’d rather avoid. Thanks.

cc @datumbox @YosuaMichael

Issue Analytics

  • State: closed
  • Created a year ago
  • Reactions: 1
  • Comments: 7 (5 by maintainers)

Top GitHub Comments

YosuaMichael commented, Jun 30, 2022 (2 reactions)

Hi @paullixo, I think you can still do:

torchvision.models.detection.roi_heads.RoIHeads = RoIHeadsCustom
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(
    pretrained=True
)

You can create a class RoIHeadsCustom that inherits from RoIHeads and overrides the postprocess_detections method:

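import torchvision

class RoIHeadsCustom(torchvision.models.detection.roi_heads.RoIHeads):
    def postprocess_detections(self, class_logits, box_regression, proposals, image_shapes):
        # custom logic goes here; it must return the (boxes, scores, labels) lists that forward() unpacks
        return super().postprocess_detections(class_logits, box_regression, proposals, image_shapes)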
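Whether this takes effect depends on which name the model constructor reads. The traceback above shows `faster_rcnn.py` calling a bare `RoIHeads(...)`; assuming that name comes from `from .roi_heads import RoIHeads`, the constructor keeps its own reference, and the patch would need to target `torchvision.models.detection.faster_rcnn.RoIHeads` instead. The sketch below checks that pattern on hypothetical stand-in modules, using `unittest.mock.patch.object` so the original name is restored afterwards. Because `RoIHeadsCustom` subclasses the original class, the inherited two-argument `super()` call succeeds.

```python
import textwrap
import types
from unittest import mock

# Hypothetical stand-in for roi_heads.py (explicit two-argument super()).
fake_roi_heads = types.ModuleType("fake_roi_heads")
exec(
    textwrap.dedent(
        """
        class Base:
            def __init__(self):
                pass

        class RoIHeads(Base):
            def __init__(self):
                super(RoIHeads, self).__init__()

            def postprocess_detections(self):
                return "default"
        """
    ),
    fake_roi_heads.__dict__,
)

# Hypothetical stand-in for faster_rcnn.py, which holds its own imported reference.
fake_faster_rcnn = types.ModuleType("fake_faster_rcnn")
fake_faster_rcnn.RoIHeads = fake_roi_heads.RoIHeads
exec(
    textwrap.dedent(
        """
        def build_model():
            # RoIHeads is looked up in this module's globals at call time
            return RoIHeads()
        """
    ),
    fake_faster_rcnn.__dict__,
)


class RoIHeadsCustom(fake_roi_heads.RoIHeads):
    def postprocess_detections(self):
        return "custom"


# Patch the name the constructor reads, only while the model is being built.
with mock.patch.object(fake_faster_rcnn, "RoIHeads", RoIHeadsCustom):
    heads = fake_faster_rcnn.build_model()

print(type(heads).__name__, heads.postprocess_detections())  # RoIHeadsCustom custom
print(fake_faster_rcnn.RoIHeads is fake_roi_heads.RoIHeads)  # True
```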

YosuaMichael commented, Jul 4, 2022 (1 reaction)

Hi @paullixo, strangely enough your code works on my laptop. However, I think @datumbox's suggestion to instantiate a new roi_heads module is a better way to do it. Here is the full code using @datumbox's suggestion:

from typing import List, Tuple

import torch
import torchvision
from torch import Tensor

class RoIHeadsCustom(torchvision.models.detection.roi_heads.RoIHeads):
    def postprocess_detections(
        self,
        class_logits,  # type: Tensor
        box_regression,  # type: Tensor
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
    ):
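        # dummy override to test the customization: log the call, then delegate
        # to the original implementation, because forward() unpacks its
        # (boxes, scores, labels) return value
        print("custom postprocess function")
        return super().postprocess_detections(class_logits, box_regression, proposals, image_shapes)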

model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(
    pretrained=True
)
# keep a short alias to the original roi_heads to read its parameters
r = model.roi_heads
# Instantiate the custom roi_heads with the same submodules and settings (see https://github.com/pytorch/vision/blob/main/torchvision/models/detection/roi_heads.py#L492)
new_roi_heads = RoIHeadsCustom(r.box_roi_pool, r.box_head, r.box_predictor, 
              r.proposal_matcher.high_threshold, r.proposal_matcher.low_threshold, 
              r.fg_bg_sampler.batch_size_per_image, r.fg_bg_sampler.positive_fraction,
              r.box_coder.weights, r.score_thresh, r.nms_thresh, r.detections_per_img,
              r.mask_roi_pool, r.mask_head, r.mask_predictor,
              r.keypoint_roi_pool, r.keypoint_head, r.keypoint_predictor)

model.roi_heads = new_roi_heads

Basically, you need to copy the parameters from the original roi_heads and pass them to the RoIHeadsCustom constructor. Hope this works for you!
