[RFC] Use `pretrained=True` to load the best available pre-trained weights
Background Info
To access pre-trained models in TorchVision, one passes `pretrained=True` to the model builders. Example:

```python
from torchvision.models import resnet50

# With weights:
model = resnet50(pretrained=True)
# Without weights:
model = resnet50(pretrained=False)
```
Unfortunately, the above API does not allow us to support multiple pre-trained weights. This feature is necessary when we want to provide improved weights for the same dataset (for example, better Acc@1 on ImageNet) or additional weights trained on a different dataset (for example, Object Detection weights trained on VOC instead of COCO). With the completion of the Multi-weight support prototype, the TorchVision model builders can now support more than one set of weights:
```python
from torchvision.prototype.models import resnet50, ResNet50_Weights

# Old weights:
model = resnet50(weights=ResNet50_Weights.ImageNet1K_V1)
# New weights:
model = resnet50(weights=ResNet50_Weights.ImageNet1K_V2)
# No weights:
model = resnet50(weights=None)
```
The above prototype API is now available in nightly builds, where users can test it and provide feedback. Once the feedback is gathered and acted upon, we will consider releasing the new API in the main area.
What should be the behaviour of `pretrained=True`?
Upon release, the legacy `pretrained=True` parameter will be deprecated and removed in a future version of TorchVision (TBD when). The question of this RFC is what the behaviour of `pretrained=True` should be until its removal. There are currently two obvious candidates:
Option 1: Using the Legacy weights
With `pretrained=True`, the new API should return the same legacy weights as the current API. This is how the prototype is currently implemented. The following calls are all equivalent:

```python
# Legacy weights with accuracy 76.130%
model = resnet50(weights=ResNet50_Weights.ImageNet1K_V1)
model = resnet50(pretrained=True)
model = resnet50(True)
```
Why select this option:
- It is aligned with TorchVision's strong Backwards Compatibility guarantees.
- It requires a "manual opt-in" from users to switch to the new weights.
- It is the safest option.
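The Option 1 mapping can be sketched with a hypothetical helper. The enum and function below are illustrative stand-ins, not TorchVision's actual implementation; the real `ResNet50_Weights` members carry checkpoint URLs and metadata:

```python
import warnings
from enum import Enum

# Illustrative stand-in for torchvision's weight enums.
class ResNet50_Weights(Enum):
    ImageNet1K_V1 = "v1"  # legacy weights, Acc@1 76.130%
    ImageNet1K_V2 = "v2"  # new weights, Acc@1 80.674%

def resolve_weights(pretrained=False, weights=None):
    """Under Option 1, pretrained=True keeps resolving to the legacy weights."""
    if weights is not None:
        return weights
    if pretrained:
        warnings.warn(
            "pretrained=True is deprecated; pass "
            "weights=ResNet50_Weights.ImageNet1K_V1 instead.",
            DeprecationWarning,
        )
        return ResNet50_Weights.ImageNet1K_V1  # unchanged legacy behaviour
    return None
```

The key property is that existing code keeps getting byte-identical checkpoints while being nudged, via the deprecation warning, towards the explicit `weights=` argument.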
Option 2: Using the Best available weights
With `pretrained=True`, the new API should return the best available weights. The following calls would be made equivalent:

```python
# New weights with accuracy 80.674%
model = resnet50(weights=ResNet50_Weights.ImageNet1K_V2)
model = resnet50(weights=ResNet50_Weights.default)
model = resnet50(pretrained=True)
model = resnet50(True)
```
Why select this option:
- Users will automatically benefit from the major accuracy improvement.
- In practice, TorchVision didn't actually offer BC guarantees on the weights. There are several instances where we previously modified the weights in-place [1, 2, 3, 4]. Due to this, one could argue that the semantics of `pretrained=True` always meant "give me the current best weights".
- The in-place modification of weights is commonplace in other libraries [1, 2].
- It emphasises the fact that ResNet50 and other older architectures achieve very high accuracies if trained with modern approaches, and this can have a positive influence on research [1].
To address some of the cons of adopting this option, we will:
- Raise warnings to inform users that they are accessing new weights, and provide information within the warning on how to switch to the old behaviour.
- Inform downstream libraries and users about the upcoming change via blog posts, social media, and even by opening PRs to their projects (especially for Meta-backed projects).
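The warning behaviour under Option 2 could look roughly like the sketch below. Again, the enum and helper are illustrative, not the actual TorchVision code; note that `default` is an enum alias of the best checkpoint, mirroring `ResNet50_Weights.default` in the prototype:

```python
import warnings
from enum import Enum

# Illustrative stand-in; in the prototype, ResNet50_Weights.default
# points at the best available checkpoint.
class ResNet50_Weights(Enum):
    ImageNet1K_V1 = "v1"
    ImageNet1K_V2 = "v2"
    default = "v2"  # same value as ImageNet1K_V2, so it becomes an alias

def resolve_weights(pretrained=False):
    """Under Option 2, pretrained=True resolves to the best weights."""
    if not pretrained:
        return None
    warnings.warn(
        "pretrained=True now loads the best available weights "
        "(ImageNet1K_V2, Acc@1 80.674%). To keep the old behaviour, pass "
        "weights=ResNet50_Weights.ImageNet1K_V1 explicitly.",
        UserWarning,
    )
    return ResNet50_Weights.default
```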
Feedback
We would love to hear your thoughts on the matter. You can vote on this topic on Twitter or explain your position in the comments.
Issue Analytics
- Created 2 years ago
- Reactions: 5
- Comments: 7 (3 by maintainers)
Top GitHub Comments
I have quite often needed to precompute features using the torchvision resnets (such as for indexing with FAISS) and relied on being able to get a comparable feature for new images by just creating a new `resnet18(pretrained=True)`, so this sort of use case would break quite badly if the weights changed. (I would at least want loud warnings so that incompatible features aren't silently added to an index after upgrading torchvision.)

I lean towards option 1.
The reason being that paper implementations don't generally get updated once they are released to the public. Reproducing research is fundamental, and option 2 introduces a silent BC breakage that can make reproducing the results of the paper impossible, without any warning to the user as to why (and raising warnings all the time is very annoying).
In the same way, if a downstream user is on an older version of torchvision (say 0.11) and is pulling a repository which used torchvision 0.13 (with `pretrained=True` meaning get the new weights), they will also not be able to reproduce the results, adding one extra layer of complication to working out why the results could not be reproduced.
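One way downstream code could defend against silently mixed checkpoints, regardless of which option is chosen, is to key any precomputed-feature store on an explicit weights identifier. A minimal, hypothetical sketch (the function and identifier format are inventions for illustration):

```python
# Hypothetical guard for a precomputed-feature index (e.g. FAISS):
# store the identifier of the weights used to build the index, and refuse
# to add features computed under a different checkpoint.
def check_index_compat(index_weights_id: str, current_weights_id: str) -> None:
    if index_weights_id != current_weights_id:
        raise RuntimeError(
            f"Index was built with weights {index_weights_id!r} but the "
            f"current model uses {current_weights_id!r}; features from the "
            "two checkpoints are not comparable."
        )

# Usage: record e.g. "ResNet50_Weights.ImageNet1K_V1" when the index is
# created, and call check_index_compat() before adding new features.
```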