
Adding an argument to exclude some states (pretrained weights) from being loaded.

🚀 Feature request

Adding an argument in from_pretrained to exclude some states (pretrained weights) from being loaded.

Motivation

In general, we use the from_pretrained method to load pretrained states, from the CDN or from local files, into the model. However, when I need to adjust the shape of certain layers (submodules), errors are raised because of mismatched shapes.

For example, in the following snippet I changed the embedding_size of ELECTRA in order to tie its embeddings with BERT's in subsequent code, but because of the mismatched shapes, several RuntimeErrors were raised in module._load_from_state_dict.

from transformers import BertModel, BertConfig, ElectraModel, ElectraConfig

bert_config = BertConfig.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

# Raise ELECTRA's embedding size to BERT's hidden size (768) so the embeddings can be tied later
electra_config = ElectraConfig.from_pretrained(
    'google/electra-small-generator',
    embedding_size=bert_config.hidden_size
)
electra_model = ElectraModel.from_pretrained('google/electra-small-generator', config=electra_config)

This raises:

Exception has occurred: RuntimeError
Error(s) in loading state_dict for ElectraModel:
    size mismatch for electra.embeddings.word_embeddings.weight: copying a param with shape torch.Size([30522, 128]) from checkpoint, the shape in current model is torch.Size([30522, 768]).
    size mismatch for electra.embeddings.position_embeddings.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([512, 768]).
    size mismatch for electra.embeddings.token_type_embeddings.weight: copying a param with shape torch.Size([2, 128]) from checkpoint, the shape in current model is torch.Size([2, 768]).
    size mismatch for electra.embeddings.LayerNorm.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for electra.embeddings.LayerNorm.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for electra.embeddings_project.weight: copying a param with shape torch.Size([256, 128]) from checkpoint, the shape in current model is torch.Size([256, 768]).

Therefore, I think it would be better to add an argument like excluded_keys (as in the example below) to from_pretrained to explicitly prevent certain states from being loaded, or an argument that automatically skips states with mismatched shapes. I know there are workarounds, such as loading all states first and then tying each weight separately (sketched below), but they lead to long and verbose code.
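
For context, the manual workaround I mean looks roughly like the following sketch. The submodule names match the keys in the error above; the handling of embeddings_project and the config update at the end are illustrative, not a definitive recipe.

import torch.nn as nn
from transformers import BertModel, ElectraModel

bert_model = BertModel.from_pretrained('bert-base-uncased')

# Load ELECTRA with its original config so every checkpoint weight matches.
electra_model = ElectraModel.from_pretrained('google/electra-small-generator')

# Then tie/replace each embedding submodule with BERT's, one by one, and resize
# the projection whose input dimension depended on the old embedding size.
electra_model.embeddings.word_embeddings = bert_model.embeddings.word_embeddings
electra_model.embeddings.position_embeddings = bert_model.embeddings.position_embeddings
electra_model.embeddings.token_type_embeddings = bert_model.embeddings.token_type_embeddings
electra_model.embeddings.LayerNorm = bert_model.embeddings.LayerNorm
electra_model.embeddings_project = nn.Linear(
    bert_model.config.hidden_size, electra_model.config.hidden_size
)
electra_model.config.embedding_size = bert_model.config.hidden_size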

Example of the proposed excluded_keys argument:

electra_model = ElectraModel.from_pretrained(
    'google/electra-small-generator',
    config=electra_config,
    excluded_keys=[
        "electra.embeddings.word_embeddings.weight",
        "electra.embeddings.position_embeddings.weight",
        "electra.embeddings.token_type_embeddings.weight",
        "electra.embeddings.LayerNorm.weight",
        "electra.embeddings.LayerNorm.bias",
        "electra.embeddings_project.weight",
        "generator_predictions.LayerNorm.weight",
        "generator_predictions.LayerNorm.bias",
        "generator_predictions.dense.weight",
        "generator_predictions.dense.bias",
        "generator_lm_head.weight"
    ]
)
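
Conceptually, the new argument would amount to filtering the checkpoint's state dict before doing a non-strict load. A hypothetical helper (load_state_dict_excluding is illustrative only, not an existing transformers API) could look like:

import torch.nn as nn

def load_state_dict_excluding(model: nn.Module, state_dict: dict, excluded_keys: list):
    """Hypothetical helper: drop the excluded keys, then do a non-strict load."""
    filtered = {k: v for k, v in state_dict.items() if k not in set(excluded_keys)}
    # Excluded (and otherwise missing) keys keep the model's freshly initialized values.
    missing, unexpected = model.load_state_dict(filtered, strict=False)
    return missing, unexpected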

Your contribution

If there are no other concerns and no one is already implementing a similar feature, I would be happy to submit a PR for this.

Any thoughts are welcome 😃

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Reactions: 2
  • Comments: 6 (6 by maintainers)

Top GitHub Comments

5 reactions
sgugger commented, Jul 12, 2021

After discussing offline with @LysandreJik, we will add an ignore_mismatched_size flag to from_pretrained. When activated, weights that don't have the right size will be ignored, which should cover both use cases in this issue.

I will work on this today.

2 reactions
qqaatw commented, Jul 13, 2021

The feature has been implemented in #12664. Thanks @sgugger
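
With that change merged, the resizing example from this issue should boil down to something like the following sketch. This assumes the argument landed as ignore_mismatched_sizes; check #12664 and the from_pretrained documentation for the exact name and behavior.

from transformers import BertConfig, ElectraConfig, ElectraModel

bert_config = BertConfig.from_pretrained('bert-base-uncased')
electra_config = ElectraConfig.from_pretrained(
    'google/electra-small-generator',
    embedding_size=bert_config.hidden_size
)

# Checkpoint weights whose shapes no longer match the resized config are
# skipped and left newly initialized instead of raising a RuntimeError.
electra_model = ElectraModel.from_pretrained(
    'google/electra-small-generator',
    config=electra_config,
    ignore_mismatched_sizes=True
)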
