
Training NMT models?

See original GitHub issue

Hello! Thanks, Tim! I tried bitsandbytes for language models like BLOOM, and it works well.

I have a question about NMT models like NLLB, M2M, mBART, or OPUS. I tried inference for NLLB, and apparently it is not supported. Are any of these models supported for inference, and especially for fine-tuning?

Many thanks!

Issue Analytics

  • State: closed
  • Created: 10 months ago
  • Comments: 5

Top GitHub Comments

1 reaction
younesbelkada commented, Nov 23, 2022

Hi @ymoslem, thanks a lot for your message! Indeed, it is not possible for now to train any 8-bit model using transformers. We are currently looking into whether we can apply LoRA (Low-Rank Adapters) to 8-bit models using transformers, but it is still under discussion. We’ll keep you posted.
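For readers landing here later, below is a minimal sketch of what LoRA on top of an 8-bit model could look like, using the separate peft library. This is an assumption on our part, not something confirmed in this thread, and the target module names and hyperparameters are purely illustrative.

from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Load the 8-bit base model; its quantized weights stay frozen.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "facebook/nllb-200-distilled-600M",
    device_map="auto",
    load_in_8bit=True,
)
# Cast layer norms and the output head for training stability.
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,                                 # rank of the low-rank update matrices
    lora_alpha=32,                        # scaling applied to the LoRA updates
    target_modules=["q_proj", "v_proj"],  # attention projections; illustrative choice
    lora_dropout=0.05,
    task_type="SEQ_2_SEQ_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()        # only the small adapter weights are trainable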

1 reaction
younesbelkada commented, Nov 21, 2022

Hi @ymoslem, thanks a lot for your message!

1- It is not faster than float16. When loading nllb-200-distilled-600M with load_in_8bit=True, it takes 17.1 seconds, while with torch_dtype=torch.float16 it takes 17.9 seconds.

Yes, this is expected; the 8-bit model is currently slower than the fp16 model because the 8-bit quantization is done in two stages. You can read more about that in the 8-bit integration blog post.
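For anyone who wants to reproduce this kind of comparison, here is a rough sketch (ours, not from the thread; the example sentence and run count are arbitrary) that times generation with the same inputs in both precisions:

import time
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_id = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("The weather is nice today.", return_tensors="pt").to(0)

def average_generation_time(model, n_runs=5):
    # One warm-up call, then average a few generation calls.
    model.generate(**inputs, max_new_tokens=32)
    start = time.perf_counter()
    for _ in range(n_runs):
        model.generate(**inputs, max_new_tokens=32)
    return (time.perf_counter() - start) / n_runs

model_fp16 = AutoModelForSeq2SeqLM.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.float16
)
print("fp16:", average_generation_time(model_fp16))
del model_fp16
torch.cuda.empty_cache()

model_8bit = AutoModelForSeq2SeqLM.from_pretrained(
    model_id, device_map="auto", load_in_8bit=True
)
print("int8:", average_generation_time(model_8bit))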

2- When adding int8_threshold=2.0, I got “an unexpected keyword argument” error. It seems that AutoModelForSeq2SeqLM does not support it.

Yes, please use load_in_8bit_threshold instead. Could you point me to the place where you read that int8_threshold should be used? Maybe the documentation has not been updated.
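For example, with the transformers version discussed in this thread the threshold would be passed like this (a sketch based on the keyword named in the comment; in later transformers releases this setting moved into a quantization config, BitsAndBytesConfig(llm_int8_threshold=...)):

from transformers import AutoModelForSeq2SeqLM

model_8bit = AutoModelForSeq2SeqLM.from_pretrained(
    "facebook/nllb-200-distilled-600M",
    device_map="auto",
    load_in_8bit=True,
    load_in_8bit_threshold=2.0,  # outlier threshold for the mixed int8/fp16 decomposition
)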

3- GPU consumption seems similar in both cases: 4871 MB with 8-bit and 4151 MB with float16.

Could you share how you measure that? Note that the memory saving between the fp16 and int8 models really depends on the model size: for nllb-600M you get a memory footprint saving of 1.18x, for 3.3B a saving of 1.41x, etc., and it grows with the size of the model. You can check that with this snippet:

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# NLLB language codes (not needed for the footprint comparison itself).
src_lang = "eng_Latn"
tgt_lang = "spa_Latn"

model_id = "facebook/nllb-200-3.3B"

# Load the same checkpoint once in fp16 and once in 8-bit,
# then compare the memory footprints reported by transformers.
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_id, device_map="auto", load_in_8bit=False, torch_dtype=torch.float16
)
model_8bit = AutoModelForSeq2SeqLM.from_pretrained(
    model_id, device_map="auto", load_in_8bit=True
)

print(model.get_memory_footprint() / model_8bit.get_memory_footprint())
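As a side note on measuring GPU consumption: nvidia-smi also counts the CUDA context and cached allocator blocks, so a sketch like the following (ours, not from the thread) that reads the PyTorch allocator statistics gives a cleaner per-model number; run it once per precision in a fresh process.

import torch
from transformers import AutoModelForSeq2SeqLM

torch.cuda.reset_peak_memory_stats()
model_8bit = AutoModelForSeq2SeqLM.from_pretrained(
    "facebook/nllb-200-distilled-600M", device_map="auto", load_in_8bit=True
)
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 1024**2:.0f} MiB")
# Repeat in a separate process with torch_dtype=torch.float16 instead of
# load_in_8bit=True to get the comparable fp16 number.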

Top Results From Across the Web

Training efficient neural network models for Firefox Translations
NMT models are trained as language pairs, translating from language A to language B. The training pipeline was designed to train translation ...

Tutorial: Neural Machine Translation - seq2seq - Google
However, learning a model based on words has a couple of drawbacks. Because NMT models output a probability distribution over words, they can...

NMT | NMnetwork - Neurosequential Network
The Phase I Training Certification program is organized into 10 modules and will take approximately 12 months to complete. Our program involves active ...

Training Neural Machine Translation (NMT) Models using ...
Abstract: We implement a Tensor Train layer in the TensorFlow Neural Machine Translation (NMT) model using the t3f library.

Scaling neural machine translation to bigger data sets with ...
As NMT models become increasingly successful at learning from large-scale monolingual data (data that is available only in a single ...
