mT5 getting NaNs with fp16
Environment info
- transformers version: 4.4.2
- Platform: Linux
- Python version: 3.7
- PyTorch version (GPU?): 1.8
- Tensorflow version (GPU?): -
- Using GPU in script?: -
- Using distributed or parallel set-up in script?: -
Who can help
t5: @patrickvonplaten, @patil-suraj
Information
I am using the mt5-small model:
- The problem arises when using fp16 with mT5.
The task I am working on is:
- translation (WMT16 en-ro)
To reproduce
Steps to reproduce the behavior:
python run_translation.py --model_name_or_path google/mt5-small --do_train --do_eval --source_lang en --target_lang ro --dataset_name wmt16 --dataset_config_name ro-en --output_dir test/tst-translation --per_device_train_batch_size=4 --per_device_eval_batch_size=4 --overwrite_output_dir --predict_with_generate --max_train_samples 100 --fp16
outputs:
***** eval metrics *****
epoch = 3.0
eval_bleu = 0.0039
eval_gen_len = 2.95
eval_loss = nan
eval_mem_cpu_alloc_delta = 4MB
eval_mem_cpu_peaked_delta = 5MB
eval_mem_gpu_alloc_delta = 0MB
eval_mem_gpu_peaked_delta = 1080MB
eval_runtime = 72.1865
eval_samples = 1999
eval_samples_per_second = 27.692
Expected behavior
Being able to use fp16 with mT5 models. Thank you very much for your help; running these models in fp16 is really crucial for me, since it lets me fit more data onto the older GPUs I have access to, and I appreciate your help a lot.
Top GitHub Comments
Dear @stas00, I tested the code more. Without deepspeed it works fine with the feed-forward layer set to float32, as suggested in the PR, but the moment I switch to deepspeed I still get NaNs in my runs. I would greatly appreciate it if you could spare some time and suggest how to handle the same problem in the deepspeed case. Thank you very much.
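For reference, here is a minimal sketch of the workaround described above, i.e. keeping the feed-forward output projection of each mT5 block in float32. The attribute names (block.layer[-1].DenseReluDense, wo) follow the transformers T5/mT5 implementation; treat this as an illustration of the idea rather than the exact patch from the PR, and note that whether it is sufficient can depend on how fp16 autocast is applied in your setup.

# Sketch: keep the feed-forward output projection of every encoder/decoder
# block in float32 so its activations do not overflow under fp16.
from transformers import MT5ForConditionalGeneration

model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")

for block in list(model.encoder.block) + list(model.decoder.block):
    ff = block.layer[-1].DenseReluDense  # the feed-forward sub-layer
    ff.wo = ff.wo.float()                # output projection stays in fp32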
I also used your debug code:
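(The debug output itself is not included here. For readers following the thread, this is roughly how such an activation overflow detector can be attached, assuming the DebugUnderflowOverflow helper in transformers.debug_utils, which only became available in transformers releases newer than the 4.4.2 reported above:)

from transformers import MT5ForConditionalGeneration
from transformers.debug_utils import DebugUnderflowOverflow

model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")

# registers forward hooks on every sub-module and reports the first batch and
# frame in which an inf/nan appears in inputs, outputs, or weights
debug_overflow = DebugUnderflowOverflow(model)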
Dear @stas00, I tested the code more (without deepspeed) at a larger scale: when I train on opus100 (on 20 of its languages) with mt5-small, I still get NaNs after 2000 iterations even with the fix applied. I will share a reproducible script with you soon. Thanks a lot for all the great work.