[s2s finetune] huge increase in memory demands with --fp16 native amp
While working on https://github.com/huggingface/transformers/issues/8353 I discovered that --fp16 causes a 10x+ increase in GPU memory demands.
e.g. I can run bs=12 without --fp16:
cd examples/seq2seq
export BS=12; rm -rf distilbart-cnn-12-6; python finetune.py --learning_rate=3e-5 --gpus 1 \
--do_train --do_predict --val_check_interval 0.25 --n_val 500 --num_train_epochs 2 --freeze_encoder \
--freeze_embeds --data_dir cnn_dm --max_target_length 142 --val_max_target_length=142 \
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps 1 \
--model_name_or_path sshleifer/student_cnn_12_6 --tokenizer_name facebook/bart-large \
--warmup_steps 500 --output_dir distilbart-cnn-12-6
But if I add --fp16 (with or without --fp16_opt_level O1), I get OOM even with bs=1 on an 8GB card, and it barely manages on a 24GB card - I think the increase in memory demand is more than 10x.
The OOM happens either right away during the sanity-check step, or after just 10-20 batches - so within a few seconds.
This is with pytorch-1.6. Same goes for pytorch-1.7 and 1.8-nightly.
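When the OOM hits, torch.cuda.memory_summary() shows what the allocator thinks is resident at that point. A minimal sketch of how one might capture that - the run_with_oom_report helper and the trainer/model names in the usage comment are placeholders, not part of the repo:

```python
import torch

def run_with_oom_report(train_fn):
    """Run a training callable; if it dies with a CUDA OOM, print the
    allocator's summary before re-raising, to see where the memory went."""
    try:
        train_fn()
    except RuntimeError as err:
        if "out of memory" in str(err):
            print(torch.cuda.memory_summary())
        raise

# Usage (hypothetical): run_with_oom_report(lambda: trainer.fit(model))
```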
At first I wasn't able to test --fp16 with pytorch-1.5, since I couldn't build apex on ubuntu-20.04. Without --fp16, pytorch-1.5 works the same as pytorch-1.6 GPU-memory-wise.
I have since tested pytorch-1.5 + apex and there is no problem there - memory consumption is about half.
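For context, here is a rough sketch of the training-step variants being compared - plain fp32, native amp (what --fp16 selects on pytorch >= 1.6), and apex (the pytorch-1.5 path). This is a generic loop, not the actual lightning/finetune.py internals, and it assumes the model returns the loss as its first output when labels are passed:

```python
import torch
from torch.cuda.amp import autocast, GradScaler

def fp32_step(model, optimizer, batch):
    optimizer.zero_grad()
    loss = model(**batch)[0]          # assumes loss is the first output
    loss.backward()
    optimizer.step()

def native_amp_step(model, optimizer, batch, scaler):
    # scaler = GradScaler(), created once per training run
    optimizer.zero_grad()
    with autocast():                  # forward pass runs in mixed precision
        loss = model(**batch)[0]
    scaler.scale(loss).backward()     # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)            # unscales grads, skips step on inf/nan
    scaler.update()

# apex equivalent (pytorch-1.5 path), shown for comparison:
#   from apex import amp
#   model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
#   ...
#   with amp.scale_loss(loss, optimizer) as scaled_loss:
#       scaled_loss.backward()
#   optimizer.step()
```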
Here is a table of the batch sizes that fit on an 8GB rtx-1070 (a bigger BS leads to an instant OOM):
| bs | version   |
|---:|-----------|
| 12 | pt15      |
| 20 | pt15+fp16 |
| 12 | pt16      |
|  1 | pt16+fp16 |
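The batch sizes above were found by trial and error; a hypothetical helper like the following (not part of the repo) can be used to compare peak memory of the step variants sketched earlier:

```python
import torch

def peak_memory_mb(step_fn, *args):
    """Run a single training step and return the peak GPU memory it used, in MB."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    step_fn(*args)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

# e.g. (model/optimizer/batch/scaler as in the earlier sketch):
#   print("fp32:", peak_memory_mb(fp32_step, model, optimizer, batch))
#   print("fp16:", peak_memory_mb(native_amp_step, model, optimizer, batch, scaler))
```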
If you'd like to reproduce the problem, here are the full steps:
# prep library
git clone https://github.com/huggingface/transformers
cd transformers
pip install -e .[dev]
pip install -r examples/requirements.txt
cd examples/seq2seq
# prep data
wget https://cdn-datasets.huggingface.co/summarization/cnn_dm_v2.tgz
tar -xzvf cnn_dm_v2.tgz # empty lines removed
mv cnn_cln cnn_dm
# run
export BS=12;
rm -rf distilbart-cnn-12-6
python finetune.py --learning_rate=3e-5 --gpus 1 \
--do_train --do_predict --val_check_interval 0.25 --n_val 500 --num_train_epochs 2 --freeze_encoder \
--freeze_embeds --data_dir cnn_dm --max_target_length 142 --val_max_target_length=142 \
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps 1 \
--model_name_or_path sshleifer/student_cnn_12_6 --tokenizer_name facebook/bart-large \
--warmup_steps 500 --output_dir distilbart-cnn-12-6
This issue is to track the problem and hopefully find a solution.
A fix has been applied to pytorch-nightly (https://github.com/pytorch/pytorch/pull/48696), which fixes https://github.com/pytorch/pytorch/issues/48049, and I verified with pytorch-nightly that this issue no longer leaks memory under native amp.
Please note that this change will only become available in pytorch-1.8, so until then native amp and transformers aren't going to play well at times; the solution in the meantime is to use apex.
edit: good news - it seems that pytorch-1.7.1 will have this fix too! https://github.com/pytorch/pytorch/issues/48049#issuecomment-742790722
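A rough sketch of how the "no leak" claim can be checked on one's own setup: under native amp, allocated memory should plateau after the first few steps instead of growing until OOM. The model/optimizer/dataloader names and the logging interval are placeholders:

```python
import torch
from torch.cuda.amp import autocast, GradScaler

def watch_amp_memory(model, optimizer, dataloader, log_every=10):
    """Train under native amp and print allocated GPU memory periodically;
    with a fixed batch size the numbers should level off, not keep climbing."""
    scaler = GradScaler()
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()
        with autocast():
            loss = model(**batch)[0]   # assumes loss is the first output
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        if step % log_every == 0:
            mb = torch.cuda.memory_allocated() / 2**20
            print(f"step {step}: {mb:.0f} MB allocated")
```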
Great, thanks a lot for your thorough investigation @stas00!