
Can't use pretrained MVIT Model with provided cfg

See original GitHub issue
import torch

# MViT is assumed to be the model class copied from the SlowFast repo
# (slowfast/models/video_model_builder.py).
if __name__ == '__main__':
    ckpt = torch.load('models/K400_MVIT_B_16x4_CONV.pyth', map_location='cpu')
    # cfg = CN.load_cfg(ckpt['cfg'])
    model = MViT()
    model.load_state_dict(ckpt['model_state'], strict=True)
    x = torch.randn([3, 3, 16, 224, 224])  # (batch, channels, frames, height, width)
    res = model(x)
    print(res.shape)
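(An aside on the commented-out cfg line: SlowFast checkpoints embed the training cfg as YAML, and the repo's own factory can build the matching model from it. A minimal sketch, assuming a SlowFast checkout on PYTHONPATH; CN.load_cfg and build_model exist in yacs and SlowFast respectively, though behavior may differ across versions:

import torch
from yacs.config import CfgNode as CN
from slowfast.models import build_model  # SlowFast's model factory

ckpt = torch.load('models/K400_MVIT_B_16x4_CONV.pyth', map_location='cpu')
cfg = CN.load_cfg(ckpt['cfg'])  # the checkpoint stores its training cfg as YAML
cfg.NUM_GPUS = 0                # assumption: build on CPU; recent SlowFast allows this
model = build_model(cfg)
model.load_state_dict(ckpt['model_state'], strict=True)

Even built this way, the mismatch below persists whenever the local code has diverged from the code that produced the checkpoint.)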

The MViT code is the same as in SlowFast. With the provided cfg (https://github.com/facebookresearch/SlowFast/blob/master/configs/Kinetics/MVIT_B_16x4_CONV.yaml), loading the state dict fails with:

        Missing key(s) in state_dict: "blocks.0.attn.q.weight", "blocks.0.attn.q.bias", "blocks.0.attn.k.weight", "blocks.0.attn.k.bias", "blocks.0.attn.v.weight", "blocks.0.attn.v.bias", "blocks.1.attn.q.weight", "blocks.1.attn.q.bias", "blocks.1.attn.k.weight", "blocks.1.attn.k.bias", "blocks.1.attn.v.weight", "blocks.1.attn.v.bias", "blocks.2.attn.q.weight", "blocks.2.attn.q.bias", "blocks.2.attn.k.weight", "blocks.2.attn.k.bias", "blocks.2.attn.v.weight", "blocks.2.attn.v.bias", "blocks.3.attn.q.weight", "blocks.3.attn.q.bias", "blocks.3.attn.k.weight", "blocks.3.attn.k.bias", "blocks.3.attn.v.weight", "blocks.3.attn.v.bias", "blocks.4.attn.q.weight", "blocks.4.attn.q.bias", "blocks.4.attn.k.weight", "blocks.4.attn.k.bias", "blocks.4.attn.v.weight", "blocks.4.attn.v.bias", "blocks.5.attn.q.weight", "blocks.5.attn.q.bias", "blocks.5.attn.k.weight", "blocks.5.attn.k.bias", "blocks.5.attn.v.weight", "blocks.5.attn.v.bias", "blocks.6.attn.q.weight", "blocks.6.attn.q.bias", "blocks.6.attn.k.weight", "blocks.6.attn.k.bias", "blocks.6.attn.v.weight", "blocks.6.attn.v.bias", "blocks.7.attn.q.weight", "blocks.7.attn.q.bias", "blocks.7.attn.k.weight", "blocks.7.attn.k.bias", "blocks.7.attn.v.weight", "blocks.7.attn.v.bias", "blocks.8.attn.q.weight", "blocks.8.attn.q.bias", "blocks.8.attn.k.weight", "blocks.8.attn.k.bias", "blocks.8.attn.v.weight", "blocks.8.attn.v.bias", "blocks.9.attn.q.weight", "blocks.9.attn.q.bias", "blocks.9.attn.k.weight", "blocks.9.attn.k.bias", "blocks.9.attn.v.weight", "blocks.9.attn.v.bias", "blocks.10.attn.q.weight", "blocks.10.attn.q.bias", "blocks.10.attn.k.weight", "blocks.10.attn.k.bias", "blocks.10.attn.v.weight", "blocks.10.attn.v.bias", "blocks.11.attn.q.weight", "blocks.11.attn.q.bias", "blocks.11.attn.k.weight", "blocks.11.attn.k.bias", "blocks.11.attn.v.weight", "blocks.11.attn.v.bias", "blocks.12.attn.q.weight", "blocks.12.attn.q.bias", "blocks.12.attn.k.weight", "blocks.12.attn.k.bias", "blocks.12.attn.v.weight", "blocks.12.attn.v.bias", "blocks.13.attn.q.weight", "blocks.13.attn.q.bias", "blocks.13.attn.k.weight", "blocks.13.attn.k.bias", "blocks.13.attn.v.weight", "blocks.13.attn.v.bias", "blocks.14.attn.q.weight", "blocks.14.attn.q.bias", "blocks.14.attn.k.weight", "blocks.14.attn.k.bias", "blocks.14.attn.v.weight", "blocks.14.attn.v.bias", "blocks.15.attn.q.weight", "blocks.15.attn.q.bias", "blocks.15.attn.k.weight", "blocks.15.attn.k.bias", "blocks.15.attn.v.weight", "blocks.15.attn.v.bias". 
        Unexpected key(s) in state_dict: "blocks.0.attn.qkv.weight", "blocks.0.attn.qkv.bias", "blocks.0.attn.norm_k.weight", "blocks.0.attn.norm_k.bias", "blocks.0.attn.norm_v.weight", "blocks.0.attn.norm_v.bias", "blocks.0.attn.pool_k.weight", "blocks.0.attn.pool_v.weight", "blocks.1.attn.qkv.weight", "blocks.1.attn.qkv.bias", "blocks.1.attn.norm_q.weight", "blocks.1.attn.norm_q.bias", "blocks.1.attn.norm_k.weight", "blocks.1.attn.norm_k.bias", "blocks.1.attn.norm_v.weight", "blocks.1.attn.norm_v.bias", "blocks.1.attn.pool_q.weight", "blocks.1.attn.pool_k.weight", "blocks.1.attn.pool_v.weight", "blocks.2.attn.qkv.weight", "blocks.2.attn.qkv.bias", "blocks.2.attn.norm_k.weight", "blocks.2.attn.norm_k.bias", "blocks.2.attn.norm_v.weight", "blocks.2.attn.norm_v.bias", "blocks.2.attn.pool_k.weight", "blocks.2.attn.pool_v.weight", "blocks.3.attn.qkv.weight", "blocks.3.attn.qkv.bias", "blocks.3.attn.norm_q.weight", "blocks.3.attn.norm_q.bias", "blocks.3.attn.norm_k.weight", "blocks.3.attn.norm_k.bias", "blocks.3.attn.norm_v.weight", "blocks.3.attn.norm_v.bias", "blocks.3.attn.pool_q.weight", "blocks.3.attn.pool_k.weight", "blocks.3.attn.pool_v.weight", "blocks.4.attn.qkv.weight", "blocks.4.attn.qkv.bias", "blocks.4.attn.norm_k.weight", "blocks.4.attn.norm_k.bias", "blocks.4.attn.norm_v.weight", "blocks.4.attn.norm_v.bias", "blocks.4.attn.pool_k.weight", "blocks.4.attn.pool_v.weight", "blocks.5.attn.qkv.weight", "blocks.5.attn.qkv.bias", "blocks.5.attn.norm_k.weight", "blocks.5.attn.norm_k.bias", "blocks.5.attn.norm_v.weight", "blocks.5.attn.norm_v.bias", "blocks.5.attn.pool_k.weight", "blocks.5.attn.pool_v.weight", "blocks.6.attn.qkv.weight", "blocks.6.attn.qkv.bias", "blocks.6.attn.norm_k.weight", "blocks.6.attn.norm_k.bias", "blocks.6.attn.norm_v.weight", "blocks.6.attn.norm_v.bias", "blocks.6.attn.pool_k.weight", "blocks.6.attn.pool_v.weight", "blocks.7.attn.qkv.weight", "blocks.7.attn.qkv.bias", "blocks.7.attn.norm_k.weight", "blocks.7.attn.norm_k.bias", "blocks.7.attn.norm_v.weight", "blocks.7.attn.norm_v.bias", "blocks.7.attn.pool_k.weight", "blocks.7.attn.pool_v.weight", "blocks.8.attn.qkv.weight", "blocks.8.attn.qkv.bias", "blocks.8.attn.norm_k.weight", "blocks.8.attn.norm_k.bias", "blocks.8.attn.norm_v.weight", "blocks.8.attn.norm_v.bias", "blocks.8.attn.pool_k.weight", "blocks.8.attn.pool_v.weight", "blocks.9.attn.qkv.weight", "blocks.9.attn.qkv.bias", "blocks.9.attn.norm_k.weight", "blocks.9.attn.norm_k.bias", "blocks.9.attn.norm_v.weight", "blocks.9.attn.norm_v.bias", "blocks.9.attn.pool_k.weight", "blocks.9.attn.pool_v.weight", "blocks.10.attn.qkv.weight", "blocks.10.attn.qkv.bias", "blocks.10.attn.norm_k.weight", "blocks.10.attn.norm_k.bias", "blocks.10.attn.norm_v.weight", "blocks.10.attn.norm_v.bias", "blocks.10.attn.pool_k.weight", "blocks.10.attn.pool_v.weight", "blocks.11.attn.qkv.weight", "blocks.11.attn.qkv.bias", "blocks.11.attn.norm_k.weight", "blocks.11.attn.norm_k.bias", "blocks.11.attn.norm_v.weight", "blocks.11.attn.norm_v.bias", "blocks.11.attn.pool_k.weight", "blocks.11.attn.pool_v.weight", "blocks.12.attn.qkv.weight", "blocks.12.attn.qkv.bias", "blocks.12.attn.norm_k.weight", "blocks.12.attn.norm_k.bias", "blocks.12.attn.norm_v.weight", "blocks.12.attn.norm_v.bias", "blocks.12.attn.pool_k.weight", "blocks.12.attn.pool_v.weight", "blocks.13.attn.qkv.weight", "blocks.13.attn.qkv.bias", "blocks.13.attn.norm_k.weight", "blocks.13.attn.norm_k.bias", "blocks.13.attn.norm_v.weight", "blocks.13.attn.norm_v.bias", "blocks.13.attn.pool_k.weight", 
"blocks.13.attn.pool_v.weight", "blocks.14.attn.qkv.weight", "blocks.14.attn.qkv.bias", "blocks.14.attn.norm_q.weight", "blocks.14.attn.norm_q.bias", "blocks.14.attn.pool_q.weight", "blocks.15.attn.qkv.weight", "blocks.15.attn.qkv.bias". 

It seems the MViT source code has diverged from the model that produced the original pretrained checkpoint.
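(One workaround, as a sketch rather than a verified fix, is to adapt the checkpoint instead of the code: the missing q/k/v keys line up with the unexpected fused qkv keys, so each fused tensor can be split in three. This assumes q, k, and v are stacked in that order along dim 0, as is conventional in ViT implementations:

import torch

ckpt = torch.load('models/K400_MVIT_B_16x4_CONV.pyth', map_location='cpu')
state = ckpt['model_state']
remapped = {}
for name, tensor in state.items():
    if '.attn.qkv.' in name:
        # e.g. "blocks.0.attn.qkv.weight" -> three equally sized q/k/v tensors
        for proj, part in zip(('q', 'k', 'v'), tensor.chunk(3, dim=0)):
            remapped[name.replace('.qkv.', '.{}.'.format(proj))] = part
    else:
        remapped[name] = tensor

Note that this only resolves the fused-projection mismatch: the unexpected norm_* and pool_* keys remain, and since the snippet above builds MViT() without a cfg, the pooling modules those keys belong to are likely never created in the first place.)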

Issue Analytics

  • State: open
  • Created: 2 years ago
  • Comments: 25 (1 by maintainers)

Top GitHub Comments

10 reactions
haooooooqi commented, Jul 31, 2021

Hi all, thanks for playing with PySlowFast. We made some minor upgrades to the code, and the old pretrained checkpoint is not compatible with them. I'll update the new checkpoint very soon.

Thanks

5 reactions
junting commented, Jan 21, 2022

It should work if you replace the following lines (the separate self.q, self.k, self.v projections) with the single fused self.qkv projection:

        # q = (
        #     self.q(q)
        #     .reshape(B, N, self.num_heads, C // self.num_heads)
        #     .permute(0, 2, 1, 3)
        # )
        # k = (
        #     self.k(k)
        #     .reshape(B, N, self.num_heads, C // self.num_heads)
        #     .permute(0, 2, 1, 3)
        # )
        # v = (
        #     self.v(v)
        #     .reshape(B, N, self.num_heads, C // self.num_heads)
        #     .permute(0, 2, 1, 3)
        # )

        # Fused projection: qkv has shape (3, B, num_heads, N, C // num_heads).
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]
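The constructor has to match the forward pass. A sketch of the corresponding __init__ change, where dim and qkv_bias are assumed argument names following common ViT conventions:

        # Replace the three separate projections:
        # self.q = nn.Linear(dim, dim, bias=qkv_bias)
        # self.k = nn.Linear(dim, dim, bias=qkv_bias)
        # self.v = nn.Linear(dim, dim, bias=qkv_bias)
        # with one fused projection whose weights match the checkpoint:
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)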