Adding an argument to exclude some states (pretrained weights) from being loaded.
🚀 Feature request
Adding an argument in from_pretrained to exclude some states (pretrained weights) from being loaded.
Motivation
In general, we use the from_pretrained method to load pretrained states, from the CDN or from local files, into the model. However, when I need to adjust the shape of certain layers (submodules), errors are raised because of the mismatched shapes.
For example, in the following snippet I changed the embedding_size of Electra in order to tie its embeddings with BERT's in subsequent code, but because of the mismatched shapes, many RuntimeErrors were raised in module._load_from_state_dict.
from transformers import BertModel, BertConfig, ElectraModel, ElectraConfig
bert_config = BertConfig.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')
electra_config = ElectraConfig.from_pretrained(
'google/electra-small-generator',
embedding_size=bert_config.hidden_size
)
electra_model = ElectraModel.from_pretrained('google/electra-small-generator', config=electra_config)
Exception has occurred: RuntimeError
Error(s) in loading state_dict for ElectraModel:
size mismatch for electra.embeddings.word_embeddings.weight: copying a param with shape torch.Size([30522, 128]) from checkpoint, the shape in current model is torch.Size([30522, 768]).
size mismatch for electra.embeddings.position_embeddings.weight: copying a param with shape torch.Size([512, 128]) from checkpoint, the shape in current model is torch.Size([512, 768]).
size mismatch for electra.embeddings.token_type_embeddings.weight: copying a param with shape torch.Size([2, 128]) from checkpoint, the shape in current model is torch.Size([2, 768]).
size mismatch for electra.embeddings.LayerNorm.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([768]).
size mismatch for electra.embeddings.LayerNorm.bias: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([768]).
size mismatch for electra.embeddings_project.weight: copying a param with shape torch.Size([256, 128]) from checkpoint, the shape in current model is torch.Size([256, 768]).
Therefore, I think it would be better to add an argument like excluded_keys (as in the example below) to from_pretrained to explicitly prevent certain states from being loaded, or to add an argument that automatically skips states whose shapes do not match. I know there are workarounds, such as loading all states first and then tying each weight respectively, but that results in long, verbose code (a sketch of that workaround follows the example below).
Example:
electra_model = ElectraModel.from_pretrained(
'google/electra-small-generator',
config=electra_config,
excluded_keys=[
"electra.embeddings.word_embeddings.weight",
"electra.embeddings.position_embeddings.weight",
"electra.embeddings.token_type_embeddings.weight",
"electra.embeddings.LayerNorm.weight",
"electra.embeddings.LayerNorm.bias",
"electra.embeddings_project.weight",
"generator_predictions.LayerNorm.weight",
"generator_predictions.LayerNorm.bias",
"generator_predictions.dense.weight",
"generator_predictions.dense.bias",
"generator_lm_head.weight"
]
)
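For comparison, the workaround mentioned above looks roughly like the following sketch: build the model from the modified config, load the pretrained weights into a second, unmodified instance, drop the keys whose shapes no longer match, and load the rest non-strictly. This only illustrates the status quo, not the proposed API; the mismatched embedding weights stay randomly initialized and can be tied or retrained afterwards.

from transformers import BertConfig, ElectraConfig, ElectraModel

bert_config = BertConfig.from_pretrained('bert-base-uncased')
electra_config = ElectraConfig.from_pretrained(
    'google/electra-small-generator',
    embedding_size=bert_config.hidden_size
)

# Randomly initialized model with the new embedding size.
model = ElectraModel(electra_config)

# Pretrained weights with the original embedding size.
pretrained_state = ElectraModel.from_pretrained('google/electra-small-generator').state_dict()

# Keep only the weights whose shapes still match the new architecture.
current_state = model.state_dict()
filtered_state = {
    name: tensor
    for name, tensor in pretrained_state.items()
    if name in current_state and current_state[name].shape == tensor.shape
}

# strict=False skips the dropped keys instead of raising a RuntimeError.
missing, unexpected = model.load_state_dict(filtered_state, strict=False)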
Your contribution
If there is no other concern, and no one is already implementing a similar feature, I would be happy to submit a PR for this.
Any thoughts are welcome 😃
Top GitHub Comments
After discussing offline with @LysandreJik, we will add an ignore_mismatched_size flag to from_pretrained. When activated, weights that don't have the right size will be ignored, which should cover both use cases in this issue. I will work on this today.
The feature has been implemented in #12664. Thanks @sgugger
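In released versions of transformers the argument is spelled ignore_mismatched_sizes; a minimal usage sketch for the original example:

from transformers import BertConfig, ElectraConfig, ElectraModel

bert_config = BertConfig.from_pretrained('bert-base-uncased')
electra_config = ElectraConfig.from_pretrained(
    'google/electra-small-generator',
    embedding_size=bert_config.hidden_size
)

# Weights with mismatched shapes are skipped and left randomly initialized;
# a warning lists which keys were not loaded.
electra_model = ElectraModel.from_pretrained(
    'google/electra-small-generator',
    config=electra_config,
    ignore_mismatched_sizes=True
)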