Wrong LayerNorm weight names in "bert-base-uncased" checkpoint?
Environment info
- transformers version: 4.4.0.dev0
- Platform: Linux-4.18.0-147.44.1.el8_1.x86_64-x86_64-with-glibc2.10
- Python version: 3.8.5
- PyTorch version (GPU?): 1.7.1 (True)
- Tensorflow version (GPU?): not installed (NA)
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: No
Who can help
@patrickvonplaten (issue with “bert-base-uncased” checkpoint)
Information
Model I am using (Bert, XLNet …): BERT (base, uncased)
The problem arises when: loading “bert-base-uncased” model weights from state_dict
To reproduce
Steps to reproduce the behavior:
- Download the model checkpoint from the hub:
git lfs install
git clone https://huggingface.co/bert-base-uncased
- Load the pre-trained model from the checkpoint using .from_pretrained (this sort of works):
import torch
from transformers import BertForPreTraining
model = BertForPreTraining.from_pretrained('./bert-base-uncased')
"""
[Output]:
Some weights of BertForPreTraining were not initialized from the model checkpoint at ./bert-base-uncased and are newly initialized: ['cls.predictions.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
"""
- Re-load the same weights, this time using .load_state_dict:
state_dict = torch.load('./bert-base-uncased/pytorch_model.bin')
model.load_state_dict(state_dict)
This fails and outputs:
RuntimeError: Error(s) in loading state_dict for BertForPreTraining:
Missing key(s) in state_dict: "bert.embeddings.position_ids", "bert.embeddings.LayerNorm.weight", "bert.embeddings.LayerNorm.bias", "bert.encoder.layer.0.attention.output.LayerNorm.weight", "bert.encoder.layer.0.attention.output.LayerNorm.bias", "bert.encoder.layer.0.output.LayerNorm.weight", "bert.encoder.layer.0.output.LayerNorm.bias", "bert.encoder.layer.1.attention.output.LayerNorm.weight", "bert.encoder.layer.1.attention.output.LayerNorm.bias", "bert.encoder.layer.1.output.LayerNorm.weight", "bert.encoder.layer.1.output.LayerNorm.bias", "bert.encoder.layer.2.attention.output.LayerNorm.weight", "bert.encoder.layer.2.attention.output.LayerNorm.bias", "bert.encoder.layer.2.output.LayerNorm.weight", "bert.encoder.layer.2.output.LayerNorm.bias", "bert.encoder.layer.3.attention.output.LayerNorm.weight", "bert.encoder.layer.3.attention.output.LayerNorm.bias", "bert.encoder.layer.3.output.LayerNorm.weight", "bert.encoder.layer.3.output.LayerNorm.bias", "bert.encoder.layer.4.attention.output.LayerNorm.weight", "bert.encoder.layer.4.attention.output.LayerNorm.bias", "bert.encoder.layer.4.output.LayerNorm.weight", "bert.encoder.layer.4.output.LayerNorm.bias", "bert.encoder.layer.5.attention.output.LayerNorm.weight", "bert.encoder.layer.5.attention.output.LayerNorm.bias", "bert.encoder.layer.5.output.LayerNorm.weight", "bert.encoder.layer.5.output.LayerNorm.bias", "bert.encoder.layer.6.attention.output.LayerNorm.weight", "bert.encoder.layer.6.attention.output.LayerNorm.bias", "bert.encoder.layer.6.output.LayerNorm.weight", "bert.encoder.layer.6.output.LayerNorm.bias", "bert.encoder.layer.7.attention.output.LayerNorm.weight", "bert.encoder.layer.7.attention.output.LayerNorm.bias", "bert.encoder.layer.7.output.LayerNorm.weight", "bert.encoder.layer.7.output.LayerNorm.bias", "bert.encoder.layer.8.attention.output.LayerNorm.weight", "bert.encoder.layer.8.attention.output.LayerNorm.bias", "bert.encoder.layer.8.output.LayerNorm.weight", "bert.encoder.layer.8.output.LayerNorm.bias", "bert.encoder.layer.9.attention.output.LayerNorm.weight", "bert.encoder.layer.9.attention.output.LayerNorm.bias", "bert.encoder.layer.9.output.LayerNorm.weight", "bert.encoder.layer.9.output.LayerNorm.bias", "bert.encoder.layer.10.attention.output.LayerNorm.weight", "bert.encoder.layer.10.attention.output.LayerNorm.bias", "bert.encoder.layer.10.output.LayerNorm.weight", "bert.encoder.layer.10.output.LayerNorm.bias", "bert.encoder.layer.11.attention.output.LayerNorm.weight", "bert.encoder.layer.11.attention.output.LayerNorm.bias", "bert.encoder.layer.11.output.LayerNorm.weight", "bert.encoder.layer.11.output.LayerNorm.bias", "cls.predictions.transform.LayerNorm.weight", "cls.predictions.transform.LayerNorm.bias", "cls.predictions.decoder.bias".
Unexpected key(s) in state_dict: "bert.embeddings.LayerNorm.gamma", "bert.embeddings.LayerNorm.beta", "bert.encoder.layer.0.attention.output.LayerNorm.gamma", "bert.encoder.layer.0.attention.output.LayerNorm.beta", "bert.encoder.layer.0.output.LayerNorm.gamma", "bert.encoder.layer.0.output.LayerNorm.beta", "bert.encoder.layer.1.attention.output.LayerNorm.gamma", "bert.encoder.layer.1.attention.output.LayerNorm.beta", "bert.encoder.layer.1.output.LayerNorm.gamma", "bert.encoder.layer.1.output.LayerNorm.beta", "bert.encoder.layer.2.attention.output.LayerNorm.gamma", "bert.encoder.layer.2.attention.output.LayerNorm.beta", "bert.encoder.layer.2.output.LayerNorm.gamma", "bert.encoder.layer.2.output.LayerNorm.beta", "bert.encoder.layer.3.attention.output.LayerNorm.gamma", "bert.encoder.layer.3.attention.output.LayerNorm.beta", "bert.encoder.layer.3.output.LayerNorm.gamma", "bert.encoder.layer.3.output.LayerNorm.beta", "bert.encoder.layer.4.attention.output.LayerNorm.gamma", "bert.encoder.layer.4.attention.output.LayerNorm.beta", "bert.encoder.layer.4.output.LayerNorm.gamma", "bert.encoder.layer.4.output.LayerNorm.beta", "bert.encoder.layer.5.attention.output.LayerNorm.gamma", "bert.encoder.layer.5.attention.output.LayerNorm.beta", "bert.encoder.layer.5.output.LayerNorm.gamma", "bert.encoder.layer.5.output.LayerNorm.beta", "bert.encoder.layer.6.attention.output.LayerNorm.gamma", "bert.encoder.layer.6.attention.output.LayerNorm.beta", "bert.encoder.layer.6.output.LayerNorm.gamma", "bert.encoder.layer.6.output.LayerNorm.beta", "bert.encoder.layer.7.attention.output.LayerNorm.gamma", "bert.encoder.layer.7.attention.output.LayerNorm.beta", "bert.encoder.layer.7.output.LayerNorm.gamma", "bert.encoder.layer.7.output.LayerNorm.beta", "bert.encoder.layer.8.attention.output.LayerNorm.gamma", "bert.encoder.layer.8.attention.output.LayerNorm.beta", "bert.encoder.layer.8.output.LayerNorm.gamma", "bert.encoder.layer.8.output.LayerNorm.beta", "bert.encoder.layer.9.attention.output.LayerNorm.gamma", "bert.encoder.layer.9.attention.output.LayerNorm.beta", "bert.encoder.layer.9.output.LayerNorm.gamma", "bert.encoder.layer.9.output.LayerNorm.beta", "bert.encoder.layer.10.attention.output.LayerNorm.gamma", "bert.encoder.layer.10.attention.output.LayerNorm.beta", "bert.encoder.layer.10.output.LayerNorm.gamma", "bert.encoder.layer.10.output.LayerNorm.beta", "bert.encoder.layer.11.attention.output.LayerNorm.gamma", "bert.encoder.layer.11.attention.output.LayerNorm.beta", "bert.encoder.layer.11.output.LayerNorm.gamma", "bert.encoder.layer.11.output.LayerNorm.beta", "cls.predictions.transform.LayerNorm.gamma", "cls.predictions.transform.LayerNorm.beta".
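For reference, the mismatch is easy to confirm by diffing the raw checkpoint keys against the keys the model expects (a quick inspection sketch, continuing from the model and state_dict defined above):
ckpt_keys = set(state_dict.keys())           # keys stored in pytorch_model.bin
model_keys = set(model.state_dict().keys())  # keys BertForPreTraining expects
print(sorted(model_keys - ckpt_keys))  # missing: *.LayerNorm.weight / .bias, position_ids, decoder.bias
print(sorted(ckpt_keys - model_keys))  # unexpected: *.LayerNorm.gamma / .beta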
Expected behavior
Opening the checkpoint with torch.load and then loading these weights via model.load_state_dict should match all keys successfully (in particular, all LayerNorm weights should be loaded).
Solution?
The issue seems to be that the LayerNorm weight and bias parameters were previously named gamma and beta, and the bert-base-uncased checkpoint on the hub was never updated to reflect the rename. I am on a somewhat older version of transformers / PyTorch, but this still appears to be the case with recent versions of both libraries. The test was done with the model checkpoint from the model hub on 21 May 2021.
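As far as I can tell, .from_pretrained only succeeds because transformers remaps the old names internally while loading, which a plain load_state_dict of course does not do. A possible workaround until the checkpoint itself is fixed (a minimal sketch, not an official API; the key names are taken from the error message above) is to rename the keys before calling load_state_dict:
import torch
from transformers import BertForPreTraining

model = BertForPreTraining.from_pretrained('./bert-base-uncased')
state_dict = torch.load('./bert-base-uncased/pytorch_model.bin', map_location='cpu')

# Rename the old TF-style LayerNorm parameter names to the ones torch.nn.LayerNorm uses.
remapped = {
    k.replace('LayerNorm.gamma', 'LayerNorm.weight')
     .replace('LayerNorm.beta', 'LayerNorm.bias'): v
    for k, v in state_dict.items()
}

# strict=False lets the two keys that are genuinely absent from the checkpoint
# (bert.embeddings.position_ids, cls.predictions.decoder.bias) keep their
# freshly initialized values instead of raising.
missing, unexpected = model.load_state_dict(remapped, strict=False)
print(missing)     # expected: only position_ids and decoder.bias
print(unexpected)  # expected: empty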
Top GitHub Comments
It seems like your case is a bit different. I think you are "initializing XLMRobertaModel from the checkpoint of a model trained on another task" (a pretraining checkpoint), so you have some parameters that are simply not needed (those from the language modeling head). In my case, it is the LayerNorm parameters that have the wrong name regardless of which architecture I load 😃
Edit: basically, what I mean is that your behaviour is expected while mine is not.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.