Customize FasterRCNN
Hi,
I’ve been trying, unsuccessfully, to customize the torchvision implementation of FasterRCNN a bit. For example, one thing I would like to do is write a custom postprocess_detections function that returns the confidence for every label, not only for the label with the highest confidence.
In the past I managed to override the loss function by doing something like:
import torchvision
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
# custom_loss is my own function with the same signature as fastrcnn_loss
torchvision.models.detection.roi_heads.fastrcnn_loss = custom_loss
But postprocess_detections is a method of the RoIHeads class. If I try to replace the RoIHeads class before building my model, I get this error:
torchvision.models.detection.roi_heads.RoIHeads = RoIHeadsCustom
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(
pretrained=True
)
Traceback (most recent call last):
File "test2.py", line 80, in <module>
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(
File "/home/paul/.local/lib/python3.8/site-packages/torchvision/models/detection/faster_rcnn.py", line 470, in fasterrcnn_mobilenet_v3_large_fpn
return _fasterrcnn_mobilenet_v3_large_fpn(weights_name, pretrained=pretrained, progress=progress,
File "/home/paul/.local/lib/python3.8/site-packages/torchvision/models/detection/faster_rcnn.py", line 393, in _fasterrcnn_mobilenet_v3_large_fpn
model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
File "/home/paul/.local/lib/python3.8/site-packages/torchvision/models/detection/faster_rcnn.py", line 222, in __init__
roi_heads = RoIHeads(
File "/home/paul/.local/lib/python3.8/site-packages/torchvision/models/detection/roi_heads.py", line 512, in __init__
super(RoIHeads, self).__init__()
TypeError: super(type, obj): obj must be an instance or subtype of type
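The error comes from Python name resolution. Assuming, as the traceback suggests, that faster_rcnn.py imports RoIHeads by name and that RoIHeads.__init__ calls super(RoIHeads, self), FasterRCNN still builds the original class, but inside __init__ the global name RoIHeads now points to RoIHeadsCustom, and an instance of the original class is not an instance of its own subclass. A torch-free reproduction with stand-in classes:

```python
class RoIHeads:
    def __init__(self):
        # Explicit two-argument super(): "RoIHeads" is looked up as a global when __init__ runs.
        super(RoIHeads, self).__init__()


# Stand-in for the reference faster_rcnn.py holds after `from .roi_heads import RoIHeads`.
OriginalRoIHeads = RoIHeads


class RoIHeadsCustom(RoIHeads):
    pass


# Equivalent of `torchvision.models.detection.roi_heads.RoIHeads = RoIHeadsCustom`.
RoIHeads = RoIHeadsCustom

error_message = None
try:
    OriginalRoIHeads()  # FasterRCNN still instantiates the class it imported
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

If your torchvision version uses zero-argument super().__init__() instead, the same patch may not raise at all, but it still has no effect, since FasterRCNN keeps using its own imported reference.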
But if I replace it afterwards, the model and its roi_heads object have already been created, so the custom class is never used:
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(
pretrained=True
)
torchvision.models.detection.roi_heads.RoIHeads = RoIHeadsCustom
If anyone has ideas on how to easily customize torchvision models, that would be a great help. The only solution I see is to fork torchvision, which I’d rather avoid. Thanks.
Hi @paullixo, I think you can still do:
We can create a class RoIHeadsCustom that inherits from RoIHeads and overrides the postprocess_detections method:

Hi @paullixo, strangely enough your code works on my laptop. However, I think @datumbox’s suggestion to instantiate the roi_heads module yourself is a better way to do it. Here is the full code using @datumbox’s suggestion:
Basically, you need to copy the parameters from the original roi_heads and pass them to the RoIHeadsCustom constructor. Hope this works for you!