
Loading mBART Large 50 MMT (many-to-many) is slow

See original GitHub issue

Environment info

I’m installing the library directly from master and running it in a kaggle notebook.

  • transformers version: 4.4.0.dev0
  • Platform: Linux-5.4.89+-x86_64-with-debian-buster-sid
  • Python version: 3.7.9
  • PyTorch version (GPU?): 1.7.0 (False)
  • Tensorflow version (GPU?): 2.4.1 (False)
  • Using GPU in script?: No
  • Using distributed or parallel set-up in script?: No

Who can help

Information

Model I am using (Bert, XLNet …): mBART-Large 50 MMT (many-to-many)

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

After the model's weights have been cached, loading it with from_pretrained is significantly slower than with torch.load.

The task I am working on is:

  • an official GLUE/SQUaD task: (give the name)
  • my own task or dataset: (give details below)

Machine Translation

To reproduce

Here’s the Kaggle notebook reproducing the issue, and here’s a Colab notebook showing essentially the same thing.

Steps to reproduce the behavior:

  1. Load the model with model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
  2. Save it with model.save_pretrained('./my-model')
  3. Also save it with torch.save(model, 'model.pt')
  4. Reload and time with MBartForConditionalGeneration.from_pretrained('./my-model')
  5. Reload and time with torch.load('model.pt')

The steps above can be reproduced inside a Kaggle notebook:

import torch
from transformers import MBartForConditionalGeneration

model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
model.save_pretrained('./my-model/')  # transformers format (config + weights)
torch.save(model, 'model.pt')         # whole module, pickled
%time model = MBartForConditionalGeneration.from_pretrained("./my-model/")
%time torch_model = torch.load('model.pt')
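
Outside a notebook, where the %time magic is unavailable, the same comparison can be made with time.perf_counter. The following is a sketch using the files produced above, not code from the original report:

import time
import torch
from transformers import MBartForConditionalGeneration

# Time the transformers loading path
start = time.perf_counter()
model = MBartForConditionalGeneration.from_pretrained("./my-model/")
print(f"from_pretrained: {time.perf_counter() - start:.1f}s")

# Time plain unpickling of the saved module
start = time.perf_counter()
torch_model = torch.load('model.pt')
print(f"torch.load: {time.perf_counter() - start:.1f}s")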

Note that loading with from_pretrained (step 4) is significantly slower than torch.load (step 5): the former takes over a minute, while the latter takes just a few seconds (or around 20s if the file hasn’t previously been loaded into memory; see notebook).
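
As an aside, pickling the entire nn.Module with torch.save ties the saved file to the exact class definitions available at load time. PyTorch’s generally recommended pattern, shown here as a sketch rather than as part of the report, is to persist only the state dict and rebuild the architecture from the config:

import torch
from transformers import MBartConfig, MBartForConditionalGeneration

# Save only the tensors, not the module object
torch.save(model.state_dict(), 'model_state.pt')

# Reloading requires rebuilding the architecture first
config = MBartConfig.from_pretrained('./my-model/')
reloaded = MBartForConditionalGeneration(config)
reloaded.load_state_dict(torch.load('model_state.pt'))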

Expected behavior

The model should take less than 1 minute to load if it has already been cached (see step 1).
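
A plausible explanation, not confirmed in this thread, is that from_pretrained first constructs the model with randomly initialized weights and then overwrites them from the checkpoint, whereas torch.load simply unpickles the stored tensors. Later transformers releases added a low_cpu_mem_usage flag to from_pretrained that skips the redundant initialization; the sketch below assumes a version recent enough to support it (it did not exist in the 4.4.0.dev0 used here):

from transformers import MBartForConditionalGeneration

# low_cpu_mem_usage=True avoids materializing a randomly initialized
# copy of the weights before loading the checkpoint
model = MBartForConditionalGeneration.from_pretrained(
    "./my-model/", low_cpu_mem_usage=True
)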

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

0 reactions
github-actions[bot] commented, May 9, 2021

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

Read more comments on GitHub >

Top Results From Across the Web

Loading mbart-large-50-one-to-many-mmt is very slow - Reddit
Loading mbart-large-50-one-to-many-mmt is very slow ... My computer either freezes or it takes 15-20 minutes to load the model.

facebook/mbart-large-50-many-to-many-mmt - Hugging Face
This model is a fine-tuned checkpoint of mBART-large-50. mbart-large-50-many-to-many-mmt is fine-tuned for multilingual machine translation.

Getting CUDA error when trying to train MBART Model
I have tried decreasing batch size as well as killing all processes on the GPU to prevent this error but I cannot seem...

dl-translate - Python Package Health Analysis | Snyk
TranslationModel() # Slow when you load it for the first time ... TranslationModel("facebook/mbart-large-50-many-to-many-mmt").

DL Translate: User Guide
TranslationModel() # Slow when you load it for the first time ... TranslationModel("facebook/mbart-large-50-many-to-many-mmt").
