Add info on RetinaNet fine-tuning to the docs
In the object detection tutorial, there is a helpful section showing which objects to replace to change the number of classes. Can we confirm the corresponding objects for RetinaNet, so that I can submit a PR?
```python
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# load a model pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# replace the classifier with a new one, that has
# num_classes which is user-defined
num_classes = 2  # 1 class (person) + background
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
```
I think it should go just before the commented section on changing the backbone here: https://pytorch.org/vision/stable/_modules/torchvision/models/detection/retinanet.html.
It's something like:
```python
import torchvision
from torchvision.models.detection.retinanet import RetinaNet
from torchvision.models.detection.anchor_utils import AnchorGenerator


def load_backbone():
    # COCO-pretrained RetinaNet; only its .backbone (ResNet-50 + FPN) is reused below
    backbone = torchvision.models.detection.retinanet_resnet50_fpn(
        pretrained=True)
    return backbone


def create_anchor_generator():
    # let's make the network generate 7 x 3 anchors per spatial
    # location, with 7 different sizes and 3 different aspect
    # ratios. We have a Tuple[Tuple[int]] because each feature
    # map could potentially have different sizes and
    # aspect ratios
    # Documented at https://github.com/pytorch/vision/blob/67b25288ca202d027e8b06e17111f1bcebd2046c/torchvision/models/detection/anchor_utils.py#L9
    anchor_generator = AnchorGenerator(
        sizes=((8, 16, 32, 64, 128, 256, 400),),
        aspect_ratios=((0.5, 1.0, 2.0),))
    return anchor_generator


def create_model(num_classes):
    backbone = load_backbone()
    anchor_generator = create_anchor_generator()
    model = RetinaNet(backbone.backbone, num_classes=num_classes,
                      anchor_generator=anchor_generator)
    return model
```
This yields:
```
File "/Users/benweinstein/opt/miniconda3/envs/DeepForest_pytorch/lib/python3.8/site-packages/torchvision/models/detection/anchor_utils.py", line 103, in grid_anchors
    assert len(grid_sizes) == len(strides) == len(cell_anchors)
```
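The assertion compares per-feature-map lengths: `AnchorGenerator` expects one `sizes` tuple and one `aspect_ratios` tuple per feature map, while the ResNet-50 FPN backbone of `retinanet_resnet50_fpn` returns five maps (P3 to P7). The config in `create_anchor_generator` passes a single tuple, so the lengths disagree. The pure-Python sketch below builds a five-level config modelled on torchvision's default RetinaNet anchors (three octave scales per level) and checks it. The helpers `fpn_anchor_config` and `anchors_per_location` are hypothetical; because the RetinaNet head is shared across levels, the check also requires every level to yield the same number of anchors per location.

```python
from typing import Sequence, Tuple

Sizes = Tuple[Tuple[int, ...], ...]
Ratios = Tuple[Tuple[float, ...], ...]


def fpn_anchor_config(
    base_sizes: Sequence[int] = (32, 64, 128, 256, 512),
    aspect_ratios: Sequence[float] = (0.5, 1.0, 2.0),
) -> Tuple[Sizes, Ratios]:
    """One sizes tuple per FPN level (x, x*2^(1/3), x*2^(2/3)) and the same
    aspect ratios repeated for every level."""
    sizes = tuple(
        (x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in base_sizes
    )
    ratios = (tuple(aspect_ratios),) * len(sizes)
    return sizes, ratios


def anchors_per_location(sizes: Sizes, ratios: Ratios, num_feature_maps: int) -> int:
    """Validate a config against the number of feature maps and return the
    anchor count per spatial location."""
    if not (len(sizes) == len(ratios) == num_feature_maps):
        raise ValueError(
            f"got {len(sizes)} size tuples and {len(ratios)} ratio tuples "
            f"for {num_feature_maps} feature maps"
        )
    counts = {len(s) * len(r) for s, r in zip(sizes, ratios)}
    if len(counts) != 1:
        raise ValueError("every level must produce the same number of anchors")
    return counts.pop()


NUM_FPN_LEVELS = 5  # P3-P7 for the ResNet-50 FPN RetinaNet backbone

sizes, ratios = fpn_anchor_config()
# 3 sizes x 3 aspect ratios per level
per_location = anchors_per_location(sizes, ratios, NUM_FPN_LEVELS)
```

The two tuples can then be passed as `AnchorGenerator(sizes=sizes, aspect_ratios=ratios)`. Custom sizes such as `(8, 16, ..., 400)` would need to be regrouped into five per-level tuples of equal length.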
@hgaiser let me know if I missed something. Is it required to replace anchor_generator? If you replace the anchor generator, does that mean you can't reuse the same backbone (since any internal size normalization done by RetinaNet may be distorted)? I'm almost done with my PyTorch Lightning implementation and want to leverage the pretrained model while letting users specify num_classes. Removing the anchor generator seems to work, but I'd like to submit tests with my PR, and I'm not sure how to write a test for this.
I would do:
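```python
import torch

retinanet_model = create_model(num_classes=2)
retinanet_model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
with torch.no_grad():
    predictions = retinanet_model(x)
for prediction in predictions:
    # every predicted label should index one of the 2 classifier outputs
    assert all(0 <= label < 2 for label in prediction["labels"].tolist())
```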
but I get blank predictions. All thoughts welcome on how to write an appropriate test.
```
predictions
[{'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward>), 'scores': tensor([], grad_fn=<CatBackward>), 'labels': tensor([], dtype=torch.int64)},
 {'boxes': tensor([], size=(0, 4), grad_fn=<StackBackward>), 'scores': tensor([], grad_fn=<CatBackward>), 'labels': tensor([], dtype=torch.int64)}]
```
Issue analytics: created 3 years ago; 10 comments (8 by maintainers).
Yeah unfortunately you can’t load only a partial state dict. There is a PR for this but it doesn’t seem like it will be merged. I posted there how I resolved my issue for now, which should also help with your issue.
https://github.com/pytorch/pytorch/pull/39144#issuecomment-784560497
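`load_state_dict(..., strict=False)` already tolerates missing or unexpected keys, but a tensor whose shape differs (such as a 91-class `cls_logits` layer loaded into a 2-class model) still raises a `RuntimeError`. One common workaround, not necessarily the one in the linked comment, is to drop mismatched entries before loading. `load_matching_weights` is a hypothetical helper, and the toy `nn.Sequential` models stand in for a COCO checkpoint and a resized RetinaNet.

```python
import torch
from torch import nn


def load_matching_weights(model, state_dict):
    """Copy only the entries of `state_dict` whose name and shape match `model`;
    everything else (e.g. a resized classification layer) keeps its init."""
    own_state = model.state_dict()
    compatible = {
        name: tensor
        for name, tensor in state_dict.items()
        if name in own_state and own_state[name].shape == tensor.shape
    }
    # strict=False tolerates the dropped keys; it would NOT tolerate shape
    # mismatches, which is why they are filtered out above
    result = model.load_state_dict(compatible, strict=False)
    skipped = sorted(set(state_dict) - set(compatible))
    return result.missing_keys, skipped


# toy demo: a 91-class "pretrained" model vs a 2-class target
torch.manual_seed(0)
pretrained = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 91))
target = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))
missing, skipped = load_matching_weights(target, pretrained.state_dict())
```

Applied to RetinaNet, this would load everything from the COCO checkpoint except the `cls_logits` weight and bias.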
Ugh. I just noticed that it is documented directly below the link I pasted: the model builder (`retinanet_resnet50_fpn`) takes a `num_classes` argument. My apologies.