Problem loading a finetuned model.
Hi!
There is a problem with the way models are saved and loaded. The following code should crash but doesn't:

```python
import torch
from pytorch_pretrained_bert import BertForSequenceClassification

model_fn = 'model.bin'
bert_model = 'bert-base-multilingual-cased'

model = BertForSequenceClassification.from_pretrained(bert_model, num_labels=16)
model_to_save = model.module if hasattr(model, 'module') else model
torch.save(model_to_save.state_dict(), model_fn)
print(model_to_save.num_labels)

model_state_dict = torch.load(model_fn)
loaded_model = BertForSequenceClassification.from_pretrained(bert_model, state_dict=model_state_dict)
print(loaded_model.num_labels)
```
This code prints:

```
16
2
```
The code should raise an exception when trying to load the weights of the task-specific linear layer. I'm guessing that the problem comes from `PreTrainedBertModel.from_pretrained`.
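For illustration, here is a minimal pure-Python sketch of the strict shape check that loading should perform (hypothetical names and shapes, not the library's actual code; `torch.nn.Module.load_state_dict` with `strict=True` behaves along these lines):

```python
# Sketch: strict state-dict loading should reject checkpoint tensors
# whose shapes don't match the target model's parameters.
# Shapes are modeled as plain tuples; all names are hypothetical.

def load_state_dict_strict(model_shapes, state_dict_shapes):
    """Raise RuntimeError if any shared key has a mismatched shape."""
    mismatched = [
        (key, model_shapes[key], state_dict_shapes[key])
        for key in model_shapes
        if key in state_dict_shapes and model_shapes[key] != state_dict_shapes[key]
    ]
    if mismatched:
        raise RuntimeError(f"size mismatch for: {mismatched}")

# A freshly built head with the default 2 labels...
fresh_model = {'classifier.weight': (2, 768), 'classifier.bias': (2,)}
# ...versus a checkpoint saved from a 16-label model.
checkpoint = {'classifier.weight': (16, 768), 'classifier.bias': (16,)}

try:
    load_state_dict_strict(fresh_model, checkpoint)
except RuntimeError as e:
    print('raised:', e)
```

The bug report amounts to saying that `from_pretrained` skips this check, so the 16-label checkpoint is silently reconciled with a 2-label model instead of raising.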
I would be happy to submit a PR fixing this problem, but I'm not used to working with the PyTorch loading mechanisms. @thomwolf could you give me some guidance?
Cheers!
Issue Analytics
- State:
- Created 5 years ago
- Comments:7 (4 by maintainers)

Just pass the same `num_labels` when you load your model.
With the latest transformers versions, you can use the recently introduced `ignore_mismatched_sizes=True` parameter of the `from_pretrained` method (https://github.com/huggingface/transformers/pull/12664) to specify that you'd rather drop the layers with incompatible shapes than raise a `RuntimeError`.
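Conceptually (a hypothetical pure-Python sketch, not transformers' actual implementation), `ignore_mismatched_sizes=True` amounts to dropping checkpoint entries whose shapes disagree with the freshly initialized model before loading, so those layers keep their random initialization:

```python
# Sketch of the ignore_mismatched_sizes idea; shapes are plain tuples
# and all parameter names are hypothetical.

def filter_mismatched(model_shapes, state_dict):
    """Split a checkpoint into entries that fit the model and
    entries dropped because their shapes differ."""
    kept, dropped = {}, []
    for key, tensor_shape in state_dict.items():
        if key in model_shapes and model_shapes[key] != tensor_shape:
            dropped.append(key)
        else:
            kept[key] = tensor_shape
    return kept, dropped

model = {'bert.embeddings.weight': (1000, 768), 'classifier.weight': (2, 768)}
checkpoint = {'bert.embeddings.weight': (1000, 768), 'classifier.weight': (16, 768)}

kept, dropped = filter_mismatched(model, checkpoint)
print(dropped)  # the 16-label classifier head is skipped
```

The backbone weights load as usual, while the mismatched classifier head is reported and left freshly initialized for further fine-tuning.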