Can't use pretrained MViT model with provided cfg
See original GitHub issue.

```python
import torch
# Assumed imports for this repro: MViT is the local copy of the SlowFast
# model class, and CN is yacs' CfgNode (used by the commented-out cfg line).
# from yacs.config import CfgNode as CN
from mvit import MViT  # hypothetical import path; adjust to your codebase

if __name__ == '__main__':
    ckpt = torch.load('models/K400_MVIT_B_16x4_CONV.pyth', map_location='cpu')
    # cfg = CN.load_cfg(ckpt['cfg'])
    model = MViT()
    model.load_state_dict(ckpt['model_state'], strict=True)
    input = torch.randn([3, 3, 16, 224, 224])  # (batch, channels, frames, height, width)
    res = model(input)
    print(res.shape)
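```

Before calling `load_state_dict`, it can help to diff the model's parameter names against the checkpoint's. This is a generic diagnostic sketch (not from the original report) that prints exactly which keys disagree:

```python
import torch

ckpt = torch.load('models/K400_MVIT_B_16x4_CONV.pyth', map_location='cpu')
model_keys = set(model.state_dict().keys())  # `model` is the MViT instance above
ckpt_keys = set(ckpt['model_state'].keys())

print('Missing from checkpoint:', sorted(model_keys - ckpt_keys))
print('Unexpected in checkpoint:', sorted(ckpt_keys - model_keys))
```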
The MViT code is the same as in SlowFast. With the provided cfg (https://github.com/facebookresearch/SlowFast/blob/master/configs/Kinetics/MVIT_B_16x4_CONV.yaml), loading the checkpoint fails with:

```
Missing key(s) in state_dict: "blocks.0.attn.q.weight", "blocks.0.attn.q.bias", "blocks.0.attn.k.weight", "blocks.0.attn.k.bias", "blocks.0.attn.v.weight", "blocks.0.attn.v.bias", "blocks.1.attn.q.weight", "blocks.1.attn.q.bias", "blocks.1.attn.k.weight", "blocks.1.attn.k.bias", "blocks.1.attn.v.weight", "blocks.1.attn.v.bias", "blocks.2.attn.q.weight", "blocks.2.attn.q.bias", "blocks.2.attn.k.weight", "blocks.2.attn.k.bias", "blocks.2.attn.v.weight", "blocks.2.attn.v.bias", "blocks.3.attn.q.weight", "blocks.3.attn.q.bias", "blocks.3.attn.k.weight", "blocks.3.attn.k.bias", "blocks.3.attn.v.weight", "blocks.3.attn.v.bias", "blocks.4.attn.q.weight", "blocks.4.attn.q.bias", "blocks.4.attn.k.weight", "blocks.4.attn.k.bias", "blocks.4.attn.v.weight", "blocks.4.attn.v.bias", "blocks.5.attn.q.weight", "blocks.5.attn.q.bias", "blocks.5.attn.k.weight", "blocks.5.attn.k.bias", "blocks.5.attn.v.weight", "blocks.5.attn.v.bias", "blocks.6.attn.q.weight", "blocks.6.attn.q.bias", "blocks.6.attn.k.weight", "blocks.6.attn.k.bias", "blocks.6.attn.v.weight", "blocks.6.attn.v.bias", "blocks.7.attn.q.weight", "blocks.7.attn.q.bias", "blocks.7.attn.k.weight", "blocks.7.attn.k.bias", "blocks.7.attn.v.weight", "blocks.7.attn.v.bias", "blocks.8.attn.q.weight", "blocks.8.attn.q.bias", "blocks.8.attn.k.weight", "blocks.8.attn.k.bias", "blocks.8.attn.v.weight", "blocks.8.attn.v.bias", "blocks.9.attn.q.weight", "blocks.9.attn.q.bias", "blocks.9.attn.k.weight", "blocks.9.attn.k.bias", "blocks.9.attn.v.weight", "blocks.9.attn.v.bias", "blocks.10.attn.q.weight", "blocks.10.attn.q.bias", "blocks.10.attn.k.weight", "blocks.10.attn.k.bias", "blocks.10.attn.v.weight", "blocks.10.attn.v.bias", "blocks.11.attn.q.weight", "blocks.11.attn.q.bias", "blocks.11.attn.k.weight", "blocks.11.attn.k.bias", "blocks.11.attn.v.weight", "blocks.11.attn.v.bias", "blocks.12.attn.q.weight", "blocks.12.attn.q.bias", "blocks.12.attn.k.weight", "blocks.12.attn.k.bias", "blocks.12.attn.v.weight", "blocks.12.attn.v.bias", "blocks.13.attn.q.weight", "blocks.13.attn.q.bias", "blocks.13.attn.k.weight", "blocks.13.attn.k.bias", "blocks.13.attn.v.weight", "blocks.13.attn.v.bias", "blocks.14.attn.q.weight", "blocks.14.attn.q.bias", "blocks.14.attn.k.weight", "blocks.14.attn.k.bias", "blocks.14.attn.v.weight", "blocks.14.attn.v.bias", "blocks.15.attn.q.weight", "blocks.15.attn.q.bias", "blocks.15.attn.k.weight", "blocks.15.attn.k.bias", "blocks.15.attn.v.weight", "blocks.15.attn.v.bias".
Unexpected key(s) in state_dict: "blocks.0.attn.qkv.weight", "blocks.0.attn.qkv.bias", "blocks.0.attn.norm_k.weight", "blocks.0.attn.norm_k.bias", "blocks.0.attn.norm_v.weight", "blocks.0.attn.norm_v.bias", "blocks.0.attn.pool_k.weight", "blocks.0.attn.pool_v.weight", "blocks.1.attn.qkv.weight", "blocks.1.attn.qkv.bias", "blocks.1.attn.norm_q.weight", "blocks.1.attn.norm_q.bias", "blocks.1.attn.norm_k.weight", "blocks.1.attn.norm_k.bias", "blocks.1.attn.norm_v.weight", "blocks.1.attn.norm_v.bias", "blocks.1.attn.pool_q.weight", "blocks.1.attn.pool_k.weight", "blocks.1.attn.pool_v.weight", "blocks.2.attn.qkv.weight", "blocks.2.attn.qkv.bias", "blocks.2.attn.norm_k.weight", "blocks.2.attn.norm_k.bias", "blocks.2.attn.norm_v.weight", "blocks.2.attn.norm_v.bias", "blocks.2.attn.pool_k.weight", "blocks.2.attn.pool_v.weight", "blocks.3.attn.qkv.weight", "blocks.3.attn.qkv.bias", "blocks.3.attn.norm_q.weight", "blocks.3.attn.norm_q.bias", "blocks.3.attn.norm_k.weight", "blocks.3.attn.norm_k.bias", "blocks.3.attn.norm_v.weight", "blocks.3.attn.norm_v.bias", "blocks.3.attn.pool_q.weight", "blocks.3.attn.pool_k.weight", "blocks.3.attn.pool_v.weight", "blocks.4.attn.qkv.weight", "blocks.4.attn.qkv.bias", "blocks.4.attn.norm_k.weight", "blocks.4.attn.norm_k.bias", "blocks.4.attn.norm_v.weight", "blocks.4.attn.norm_v.bias", "blocks.4.attn.pool_k.weight", "blocks.4.attn.pool_v.weight", "blocks.5.attn.qkv.weight", "blocks.5.attn.qkv.bias", "blocks.5.attn.norm_k.weight", "blocks.5.attn.norm_k.bias", "blocks.5.attn.norm_v.weight", "blocks.5.attn.norm_v.bias", "blocks.5.attn.pool_k.weight", "blocks.5.attn.pool_v.weight", "blocks.6.attn.qkv.weight", "blocks.6.attn.qkv.bias", "blocks.6.attn.norm_k.weight", "blocks.6.attn.norm_k.bias", "blocks.6.attn.norm_v.weight", "blocks.6.attn.norm_v.bias", "blocks.6.attn.pool_k.weight", "blocks.6.attn.pool_v.weight", "blocks.7.attn.qkv.weight", "blocks.7.attn.qkv.bias", "blocks.7.attn.norm_k.weight", "blocks.7.attn.norm_k.bias", "blocks.7.attn.norm_v.weight", "blocks.7.attn.norm_v.bias", "blocks.7.attn.pool_k.weight", "blocks.7.attn.pool_v.weight", "blocks.8.attn.qkv.weight", "blocks.8.attn.qkv.bias", "blocks.8.attn.norm_k.weight", "blocks.8.attn.norm_k.bias", "blocks.8.attn.norm_v.weight", "blocks.8.attn.norm_v.bias", "blocks.8.attn.pool_k.weight", "blocks.8.attn.pool_v.weight", "blocks.9.attn.qkv.weight", "blocks.9.attn.qkv.bias", "blocks.9.attn.norm_k.weight", "blocks.9.attn.norm_k.bias", "blocks.9.attn.norm_v.weight", "blocks.9.attn.norm_v.bias", "blocks.9.attn.pool_k.weight", "blocks.9.attn.pool_v.weight", "blocks.10.attn.qkv.weight", "blocks.10.attn.qkv.bias", "blocks.10.attn.norm_k.weight", "blocks.10.attn.norm_k.bias", "blocks.10.attn.norm_v.weight", "blocks.10.attn.norm_v.bias", "blocks.10.attn.pool_k.weight", "blocks.10.attn.pool_v.weight", "blocks.11.attn.qkv.weight", "blocks.11.attn.qkv.bias", "blocks.11.attn.norm_k.weight", "blocks.11.attn.norm_k.bias", "blocks.11.attn.norm_v.weight", "blocks.11.attn.norm_v.bias", "blocks.11.attn.pool_k.weight", "blocks.11.attn.pool_v.weight", "blocks.12.attn.qkv.weight", "blocks.12.attn.qkv.bias", "blocks.12.attn.norm_k.weight", "blocks.12.attn.norm_k.bias", "blocks.12.attn.norm_v.weight", "blocks.12.attn.norm_v.bias", "blocks.12.attn.pool_k.weight", "blocks.12.attn.pool_v.weight", "blocks.13.attn.qkv.weight", "blocks.13.attn.qkv.bias", "blocks.13.attn.norm_k.weight", "blocks.13.attn.norm_k.bias", "blocks.13.attn.norm_v.weight", "blocks.13.attn.norm_v.bias", "blocks.13.attn.pool_k.weight", "blocks.13.attn.pool_v.weight", 
"blocks.14.attn.qkv.weight", "blocks.14.attn.qkv.bias", "blocks.14.attn.norm_q.weight", "blocks.14.attn.norm_q.bias", "blocks.14.attn.pool_q.weight", "blocks.15.attn.qkv.weight", "blocks.15.attn.qkv.bias".
It seems the MViT source code differs from the code used to produce the original pretrained checkpoints: the checkpoint stores a fused `attn.qkv` projection (plus pooling/norm parameters), while this model defines separate `attn.q`, `attn.k`, and `attn.v` layers.
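One workaround, if you want to keep the separate q/k/v layers, is to remap the checkpoint before loading it. The sketch below assumes the fused projection stacks q, k, v along dim 0 in that order (the usual SlowFast/timm convention); `split_qkv` is a hypothetical helper, not part of either codebase:

```python
import torch

def split_qkv(state_dict):
    """Split each fused `attn.qkv` tensor into separate q/k/v entries."""
    new_sd = {}
    for name, tensor in state_dict.items():
        if '.attn.qkv.' in name:
            # Assumes rows are stacked as [q; k; v] along dim 0.
            q, k, v = tensor.chunk(3, dim=0)
            new_sd[name.replace('.qkv.', '.q.')] = q
            new_sd[name.replace('.qkv.', '.k.')] = k
            new_sd[name.replace('.qkv.', '.v.')] = v
        else:
            new_sd[name] = tensor
    return new_sd

ckpt = torch.load('models/K400_MVIT_B_16x4_CONV.pyth', map_location='cpu')
# strict=False because the checkpoint still carries norm_k/norm_v/pool_*
# parameters that this model does not define; inspect what was skipped.
missing, unexpected = model.load_state_dict(split_qkv(ckpt['model_state']),
                                            strict=False)
print(missing, unexpected)
```

Note that the leftover `norm_*`/`pool_*` keys suggest the checkpoint's pooling attention differs from this model beyond naming, so verify outputs after loading this way.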
Hi all, thanks for playing with PySlowFast. We made some minor upgrades to the code, which make the old pretrained checkpoints incompatible; I'll update the new checkpoints very soon.
Thanks
It should work if you replace the separate projection layers (self.q, self.k, self.v) in the attention module with a single fused self.qkv projection, matching the checkpoint, as in the sketch below.
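A minimal sketch of that change, assuming a plain multi-head attention block (the class name and layout here are illustrative, not SlowFast's actual module):

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    """Attention with a single fused qkv projection, so parameter names
    match the released checkpoint (blocks.N.attn.qkv.*)."""

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Before: self.q, self.k, self.v as three nn.Linear(dim, dim) layers.
        # After: one projection that produces q, k, and v in a single matmul.
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)  # each (B, heads, N, head_dim)
        attn = (q @ k.transpose(-2, -1)) * self.head_dim ** -0.5
        out = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(out)
```

This matches the `attn.qkv.weight`/`attn.qkv.bias` names in the checkpoint; the real MViT attention additionally pools q/k/v (the `pool_*`/`norm_*` parameters in the error above), which this sketch omits.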