RuntimeError when loading model checkpoint


Hi @andi611 @leo19941227

When I try to load the state dict for my PyTorch model, I get the following error:

  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1406, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for CustomModel:
        size mismatch for extractor_s3prl.extracter._melscale.fb: copying a param with shape torch.Size([513, 64]) from checkpoint, the shape in current model is torch.Size([0]).

Could you guide me on how to resolve this?

Below is just a snippet roughly representing my model code for your reference. After training my model, I saved its model state.

class CustomModel(nn.Module):
    def __init__(self, ...):
        super(CustomModel, self).__init__()

        self.s3prl_feature_type = 'baseline_local'
        self.extractor_s3prl = torch.hub.load('s3prl/s3prl', self.s3prl_feature_type, model_config='path_to_custom_mel.yaml').to('cuda')

Sincerely, Soham Tiwari
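A size mismatch like the one reported above can be located before calling load_state_dict by comparing shapes key by key. The following is a minimal sketch, not part of the original thread: `find_shape_mismatches` is a hypothetical helper, and the `SimpleNamespace` objects are stand-ins for tensors so the example runs without PyTorch. It assumes only that each value exposes a `.shape` attribute, which `torch.Tensor` does.

```python
from types import SimpleNamespace
from typing import Any, List, Mapping, Tuple


def find_shape_mismatches(
    model_state: Mapping[str, Any],
    checkpoint_state: Mapping[str, Any],
) -> List[Tuple[str, tuple, tuple]]:
    """Return (key, checkpoint_shape, model_shape) for shared keys whose shapes differ."""
    mismatches = []
    for key, ckpt_value in checkpoint_state.items():
        if key not in model_state:
            continue  # missing or unexpected keys are a separate problem
        ckpt_shape = tuple(ckpt_value.shape)
        model_shape = tuple(model_state[key].shape)
        if ckpt_shape != model_shape:
            mismatches.append((key, ckpt_shape, model_shape))
    return mismatches


# Stand-in "tensors" mirroring the shapes from the error message.
key = "extractor_s3prl.extracter._melscale.fb"
model_state = {key: SimpleNamespace(shape=(0,))}
checkpoint_state = {key: SimpleNamespace(shape=(513, 64))}

mismatches = find_shape_mismatches(model_state, checkpoint_state)
for name, ckpt_shape, model_shape in mismatches:
    print(f"{name}: checkpoint {ckpt_shape} vs model {model_shape}")
```

With a real model, the same helper could be called as find_shape_mismatches(model.state_dict(), checkpoint['model']).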

Issue Analytics

  • State: closed
  • Created 2 years ago
  • Comments: 7 (5 by maintainers)

Top GitHub Comments

1 reaction
andi611 commented, Sep 14, 2021

@leo19941227 I found this part in the source code:

if n_stft is None or n_stft == 0:
    warnings.warn(
        'Initialization of torchaudio.transforms.MelScale with an unset weight '
        '`n_stft=None` is deprecated and will be removed in release 0.10. '
        'Please set a proper `n_stft` value. Typically this is `n_fft // 2 + 1`. '
        'Refer to https://github.com/pytorch/audio/issues/1510 '
        'for more details.'
    )

Should we pass an n_stft value when we initialize MelScale?
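The warning quoted above gives the rule of thumb n_stft = n_fft // 2 + 1. As a rough sketch of that arithmetic, the helpers below (hypothetical names, not from torchaudio or s3prl) compute n_stft and check it against the filterbank shape [513, 64] from the error message. The value n_fft=1024 is an assumption; the thread does not state the FFT size used in the config.

```python
def n_stft_from_n_fft(n_fft: int) -> int:
    """Number of frequency bins in a one-sided STFT with FFT size n_fft."""
    return n_fft // 2 + 1


def check_filterbank_shape(fb_shape, n_fft: int, n_mels: int) -> bool:
    """True if a mel filterbank of shape (n_stft, n_mels) matches the settings."""
    return tuple(fb_shape) == (n_stft_from_n_fft(n_fft), n_mels)


# The checkpoint in the error message stores a filterbank of shape [513, 64].
# n_fft=1024 is an assumed value; it is one FFT size that yields 513 bins.
n_stft = n_stft_from_n_fft(1024)
matches = check_filterbank_shape((513, 64), n_fft=1024, n_mels=64)
print(n_stft, matches)
```

A value computed this way is what the `n_stft` argument of torchaudio.transforms.MelScale expects; whether the s3prl config exposes that argument is not shown in this thread.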

0 reactions
leo19941227 commented, Sep 16, 2021

Hey, this is actually a torchaudio bug and will be resolved after torchaudio==v0.10.0. You can refer to this issue. In the meantime, here is a workaround that adds just one line:

import torch
import torch.nn as nn
from s3prl.hub import baseline_local

class CustomModel(nn.Module):
    def __init__(self):

        super(CustomModel, self).__init__()

        self.extractor_s3prl = baseline_local(model_config="mel.yaml").to("cuda") 

    def forward(self, x):

        wavs = [wav.to('cuda') for wav in x.squeeze(1)]
        out_s3prl = self.extractor_s3prl(wavs)['last_hidden_state']
        x = torch.unsqueeze(out_s3prl, 1)

        output_dict = {
            'x': x
        }

        return output_dict
    
model = CustomModel()
x = torch.rand([1, 192000]).to('cuda')
out = model(x)
checkpoint = {
    'model': model.state_dict()
}
torch.save(checkpoint, 'custom_model.pth')

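# loading state dict into a freshly constructed model
model = CustomModel()

# ADD THIS LINE: run one forward pass first, so that `_melscale.fb` no longer
# has shape [0] and matches the [513, 64] tensor stored in the checkpoint
out = model(x)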

checkpoint = torch.load('custom_model.pth')
model.load_state_dict(checkpoint['model'])
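Why the extra forward pass helps can be illustrated with a toy model. The sketch below is a pure-Python stand-in, not torchaudio's or s3prl's code: `LazyFilterbank` and `shape_of` are hypothetical names, and the assumption (based on the shape [0] in the error message) is that the filterbank buffer is only sized when the module first sees an input.

```python
from typing import Dict, List, Tuple

Matrix = List[List[float]]


def shape_of(matrix: Matrix) -> Tuple[int, ...]:
    """Shape of a list-of-rows matrix; an empty matrix has shape (0,)."""
    return (len(matrix), len(matrix[0])) if matrix else (0,)


class LazyFilterbank:
    """Toy module whose `fb` buffer is sized on the first forward call."""

    def __init__(self, n_mels: int) -> None:
        self.n_mels = n_mels
        self.fb: Matrix = []  # not sized yet, analogous to shape [0]

    def forward(self, spectrum: List[float]) -> List[float]:
        if not self.fb:
            # First call: size the buffer from the input (one row per bin).
            self.fb = [[0.0] * self.n_mels for _ in spectrum]
        # Project the spectrum onto the mel bins.
        return [
            sum(s * row[m] for s, row in zip(spectrum, self.fb))
            for m in range(self.n_mels)
        ]

    def state_dict(self) -> Dict[str, Matrix]:
        return {"fb": [row[:] for row in self.fb]}

    def load_state_dict(self, state: Dict[str, Matrix]) -> None:
        # Strict shape check, mimicking the behaviour seen in the traceback.
        if shape_of(state["fb"]) != shape_of(self.fb):
            raise RuntimeError(
                f"size mismatch for fb: checkpoint {shape_of(state['fb'])}, "
                f"current model {shape_of(self.fb)}"
            )
        self.fb = [row[:] for row in state["fb"]]


# "Training": the buffer is sized during the forward pass, then saved.
trained = LazyFilterbank(n_mels=64)
trained.forward([0.0] * 513)
checkpoint = trained.state_dict()

# A fresh model still has an empty buffer, so loading fails.
fresh = LazyFilterbank(n_mels=64)
try:
    fresh.load_state_dict(checkpoint)
    load_failed = False
except RuntimeError:
    load_failed = True

# One warm-up forward pass sizes the buffer; loading then succeeds.
fresh.forward([0.0] * 513)
fresh.load_state_dict(checkpoint)
```

The real fix in the thread is the single `out = model(x)` line; the toy only mirrors the order of operations.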

