knn_eval() with resnet50 has missing keys in state_dict
While the fc layer is not needed when extracting features from ResNet50, the following command
$ python eval_knn.py --dump_features resnet50_features --arch resnet50 --data_path imagenet1k_folder
generates this error:
RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: "fc.weight", "fc.bias".
Here is the complete output:
Will run the code on one GPU.
| distributed init (rank 0): env://
fatal: Not a git repository (or any parent up to mount point /home)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).
git:
sha: N/A, status: clean, branch: N/A
arch: resnet50
batch_size_per_gpu: 128
checkpoint_key: teacher
data_path: imagenet1k_folder
dist_url: env://
dump_features: resnet50_features
gpu: 0
load_features: None
local_rank: 0
nb_knn: [10, 20, 100, 200]
num_workers: 10
patch_size: 16
pretrained_weights:
rank: 0
temperature: 0.07
use_cuda: True
world_size: 1
/home/user/anaconda3/envs/vissl/lib/python3.7/site-packages/torchvision/transforms/transforms.py:258: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.
"Argument interpolation should be of type InterpolationMode instead of int. "
/home/user/anaconda3/envs/vissl/lib/python3.7/site-packages/torch/utils/data/dataloader.py:477: UserWarning: This DataLoader will create 10 worker processes in total. Our suggested max number of worker in current system is 8, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
cpuset_checked))
Data loaded with 1281167 train and 50000 val imgs.
Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.
Since no pretrained weights have been provided, we load the reference pretrained DINO weights.
Traceback (most recent call last):
File "eval_knn.py", line 227, in <module>
train_features, test_features, train_labels, test_labels = extract_feature_pipeline(args)
File "eval_knn.py", line 70, in extract_feature_pipeline
utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)
File "/home/user/codes/dino-main/utils.py", line 107, in load_pretrained_weights
model.load_state_dict(state_dict, strict=True)
File "/home/user/anaconda3/envs/vissl/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1224, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ResNet:
Missing key(s) in state_dict: "fc.weight", "fc.bias".
Top GitHub Comments
Hi @woctezuma,
Thanks for the quick response. I tried your suggested solution of setting strict=False, and the error is gone. However, this leads to the following unexpected result for --arch resnet50. ❌
Nevertheless, keeping strict=True but adding model.fc = nn.Identity() after line 103 in https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/utils.py#L102-L107 provided the expected result as follows. ✔️
Good job! 😄
You get Top1 67.49% with 20-NN, which is as reported in Table 2 in the article.
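
The difference between the two approaches can be reproduced on a toy model. The sketch below is not the DINO code: TinyBackbone is a hypothetical stand-in for the torchvision ResNet, and all helper names are made up for illustration. It assumes only that the checkpoint holds no fc.* keys, as in the error above. With strict=False the load succeeds, but the fc layer keeps its random initialization and still sits between the backbone and the output, which is probably why the k-NN numbers came out wrong. Replacing fc with nn.Identity() removes those parameters from the model, so a strict load passes and the model returns the backbone features directly.

```python
import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in for a ResNet: a trunk followed by an `fc` head."""

    def __init__(self):
        super().__init__()
        self.trunk = nn.Linear(8, 4)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(self.trunk(x))


def checkpoint_without_fc():
    # Simulate a backbone-only checkpoint: every key except the fc.* ones.
    full = TinyBackbone().state_dict()
    return {k: v for k, v in full.items() if not k.startswith("fc.")}


def strict_load_error(state_dict):
    # Returns the error message raised by a strict load, or None on success.
    model = TinyBackbone()
    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError as err:
        return str(err)
    return None


def missing_keys_nonstrict(state_dict):
    # strict=False does not raise; it reports which keys were left untouched.
    model = TinyBackbone()
    report = model.load_state_dict(state_dict, strict=False)
    return list(report.missing_keys)


def load_with_identity_head(state_dict):
    model = TinyBackbone()
    model.fc = nn.Identity()  # no parameters, so no fc.* keys are expected
    model.load_state_dict(state_dict, strict=True)
    return model


if __name__ == "__main__":
    sd = checkpoint_without_fc()
    print(strict_load_error(sd))
    print(missing_keys_nonstrict(sd))
    features = load_with_identity_head(sd)(torch.randn(3, 8))
    print(features.shape)
```

Keeping strict=True after the swap has a side benefit: any other mismatch between the checkpoint and the model still raises an error instead of being skipped silently.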