Training with fp16 precision gives NaN in LongT5
System Info
- transformers version: 4.10.0.dev0
- Platform: Linux-3.10.0-1160.62.1.el7.x86_64-x86_64-with-glibc2.17
- Python version: 3.8.13
- PyTorch version (GPU?): 1.9.0+cu111 (False)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: yes
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
I’m currently running the scrolls_benchmark. I’m interested in seeing how the LongT5 model performs on SCROLLS, so I changed the model name to google/long-t5-tglobal-base and ran training with fp16 enabled (if I run with fp32, I get CUDA OOM errors). However, the output loss is always nan. I searched for fixes and found this post: t5-fp16-fixed. Looking through the transformers repo, the modeling_longt5 file doesn’t seem to incorporate the clamp_value change. Could this be why fp16 is not working with LongT5? And if so, could it be fixed with a similar approach to the one you used for T5? Thank you very much!
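For reference, this is roughly what the clamp_value change looks like in modeling_t5 (a paraphrased sketch of the pattern, not the exact upstream code; in T5 it is applied after the attention and feed-forward sub-layers inside the block forward pass):

```python
import torch

def clamp_fp16_activations(hidden_states: torch.Tensor) -> torch.Tensor:
    # Under fp16, intermediate activations in T5-style blocks can overflow to inf,
    # which then propagates into a nan loss. The T5 fix clamps activations just
    # below the fp16 maximum; modeling_longt5 appears to be missing an equivalent clamp.
    if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
        clamp_value = torch.finfo(hidden_states.dtype).max - 1000
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states
```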
FYI: you probably noticed that the transformers version is 4.10.0, which does not include LongT5. I manually added the LongT5 files to a forked scrolls repo here: longt5_folder. It does work properly under a small parameter setting.
Expected behavior
The LongT5 model should not produce a nan loss with fp16.
Issue Analytics
- Created a year ago
- Reactions: 1
- Comments: 13 (3 by maintainers)
Top GitHub Comments
In general, T5 just doesn’t work well with fp16, since it was trained on bfloat16 and the model seems to require quite large activation values which are not representable in fp16. See: https://github.com/huggingface/transformers/issues/14189. Note that LongT5, MT5, T5, and ByT5 all use more or less the same architecture.
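If the hardware supports bfloat16 (Ampere or newer GPUs, or TPUs), switching mixed precision to bf16 avoids the overflow entirely, since bf16 keeps the same exponent range as fp32. A sketch with the Trainer API (bf16 support requires a recent transformers release; the argument values below are illustrative, not a tuned config):

```python
from transformers import TrainingArguments

# Sketch: train LongT5 with bfloat16 mixed precision instead of fp16.
# bf16 has the fp32 exponent range, so the large activations that overflow
# in fp16 stay finite.
training_args = TrainingArguments(
    output_dir="longt5-scrolls",      # hypothetical output directory
    bf16=True,                        # bfloat16 mixed precision (instead of fp16=True)
    per_device_train_batch_size=1,    # small batch to help with the OOM seen in fp32
    gradient_checkpointing=True,      # optional: trade compute for memory
)
```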
Based on my observations, using the clamp_value fix produces much worse results than just using fp32 under the same configuration. And following up on @patrickvonplaten’s comment, I realized this fix also makes training take longer: generating predictions on the test set takes around 1 hour with the fix, versus around 20-30 minutes with fp32.