
Problem loading a fine-tuned model.

See original GitHub issue

Hi!

There is a problem with the way models are saved and loaded. The following code should crash and doesn’t:

import torch
from pytorch_pretrained_bert import BertForSequenceClassification

model_fn = 'model.bin'
bert_model = 'bert-base-multilingual-cased'

# Build a classifier with 16 labels and save its weights.
model = BertForSequenceClassification.from_pretrained(bert_model, num_labels=16)
model_to_save = model.module if hasattr(model, 'module') else model
torch.save(model_to_save.state_dict(), model_fn)
print(model_to_save.num_labels)

# Reload the saved weights, without passing num_labels again.
model_state_dict = torch.load(model_fn)
loaded_model = BertForSequenceClassification.from_pretrained(bert_model, state_dict=model_state_dict)
print(loaded_model.num_labels)

This code prints:

16
2

The code should raise an exception when trying to load the weights of the task-specific linear layer. I’m guessing that the problem comes from PreTrainedBertModel.from_pretrained.
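
For context (an editorial note, not part of the original report): since num_labels is not passed on the second from_pretrained call, the model is presumably rebuilt with the default of 2 labels, so the saved 16-way classifier weights no longer fit its head. A minimal sketch for seeing the mismatch directly, assuming the classification head is named classifier as in pytorch_pretrained_bert’s BertForSequenceClassification:

import torch

# Inspect the saved classification head; its first dimension is the label count.
state_dict = torch.load('model.bin')
for name, tensor in state_dict.items():
    if name.startswith('classifier'):
        print(name, tuple(tensor.shape))
# Expected (sketch): classifier.weight (16, 768), classifier.bias (16,)
# A model built with the default num_labels=2 expects (2, 768) and (2,),
# yet loading the checkpoint does not raise.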

I would be happy to submit a PR fixing this problem, but I’m not used to working with PyTorch’s loading mechanisms. @thomwolf could you give me some guidance?

Cheers!

Issue Analytics

  • State: closed
  • Created: 5 years ago
  • Comments: 7 (4 by maintainers)

Top GitHub Comments

21 reactions
HamidMoghaddam commented on Jan 3, 2019

Just pass num_labels when you load your model:

model_state_dict = torch.load(model_fn)
loaded_model = BertForSequenceClassification.from_pretrained(bert_model, state_dict=model_state_dict, num_labels=16)
print(loaded_model.num_labels)
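
As a follow-up (an illustrative sketch, not from the thread): one way to avoid hard-coding the label count at load time is to bundle it with the weights in the checkpoint. The checkpoint layout and key names below are made up for illustration:

import torch
from pytorch_pretrained_bert import BertForSequenceClassification

model_fn = 'model.bin'
bert_model = 'bert-base-multilingual-cased'

model = BertForSequenceClassification.from_pretrained(bert_model, num_labels=16)
# Save the label count next to the weights (hypothetical checkpoint layout).
checkpoint = {'num_labels': model.num_labels, 'state_dict': model.state_dict()}
torch.save(checkpoint, model_fn)

# Load side: rebuild the head from the stored label count.
checkpoint = torch.load(model_fn)
loaded_model = BertForSequenceClassification.from_pretrained(
    bert_model,
    state_dict=checkpoint['state_dict'],
    num_labels=checkpoint['num_labels'])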
17 reactions
LysandreJik commented on Aug 4, 2021

With the latest transformers versions, you can use the recently introduced ignore_mismatched_sizes=True parameter of from_pretrained (https://github.com/huggingface/transformers/pull/12664) to specify that you’d rather drop the layers with incompatible shapes than raise a RuntimeError.
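
A short sketch of that usage with a recent transformers release (the checkpoint directory name is hypothetical; the parameter is the one described in the linked PR):

from transformers import BertForSequenceClassification

# 'finetuned-16-labels' is a hypothetical directory saved with save_pretrained().
model = BertForSequenceClassification.from_pretrained(
    'finetuned-16-labels',
    num_labels=4,                  # new head size differs from the checkpoint's 16
    ignore_mismatched_sizes=True,  # drop mismatched weights and re-initialize them instead of raising
)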


Top Results From Across the Web

Problem saving and/or loading fine-tuned model #3430 - GitHub
My advice to solve this would be the following: Train for 1 epoch and a tiny part of the dataset. Print out /...

Load fine tuned model from local - Hugging Face Forums
Hey, if I fine tune a BERT model is the tokenizer somehow affected? ... OSError: Model name 'Fine_tune_BERT/' was not found in tokenizers...

Loading fine-tuned model built from pretrained subnetworks
The problem I have is working with remote servers, the paths change in between the fine-tuning run and another test run so in...

How do I load a fine-tuned AllenNLP BERT-SRL model using ...
I believe I have figured this out. Basically, I had to re-load my model archive, access the underlying model and tokenizer, and then...

huggingface save fine tuned model - You.com | The AI Search ...
Interestingly, the print statements reveal in the log that one of the workers seems to load the model successfully but the other does...
