[s2s finetune] huge increase in memory demands with --fp16 native amp

While working on https://github.com/huggingface/transformers/issues/8353, I discovered that --fp16 causes a 10x+ increase in GPU memory demands.

e.g., I can run bs=12 without --fp16:

cd examples/seq2seq
export BS=12; rm -rf distilbart-cnn-12-6; python finetune.py --learning_rate=3e-5 --gpus 1 \
--do_train --do_predict --val_check_interval 0.25 --n_val 500 --num_train_epochs 2 --freeze_encoder \
--freeze_embeds --data_dir cnn_dm --max_target_length 142 --val_max_target_length=142 \
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps 1 \
--model_name_or_path sshleifer/student_cnn_12_6 --tokenizer_name facebook/bart-large \
--warmup_steps 500 --output_dir distilbart-cnn-12-6

But if I add:

--fp16

(with or without --fp16_opt_level O1)

I get OOM even with bs=1 on an 8GB card, and it barely manages on a 24GB card - I think the increase in memory demand is more than 10x.

The OOM happens either right away during the sanity-check step, or after just 10-20 batches - so within a few seconds.

This is with pytorch-1.6. The same goes for pytorch-1.7 and 1.8-nightly.

At first I wasn't able to test --fp16 with pytorch-1.5, since I couldn't build apex on ubuntu-20.04. Without --fp16, pytorch-1.5 behaves the same as pytorch-1.6 GPU memory-wise.

I did later test pytorch-1.5 + apex, and there is no problem there: memory consumption is about half of the non-fp16 run.
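
For reference, here is a minimal sketch of the two fp16 code paths being compared - native amp (torch.cuda.amp), which is what --fp16 uses on pytorch >= 1.6, and apex amp at opt_level O1, which is the pytorch-1.5 path that shows no blow-up. It uses a hypothetical toy model and random data, not the actual finetune.py / lightning internals:

# Minimal sketch: hypothetical nn.Linear model and random data, not the real
# distilbart fine-tuning setup - just the two amp APIs under comparison.
import torch
import torch.nn as nn

model = nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
data = [torch.randn(12, 1024, device="cuda") for _ in range(10)]

# Path 1: native amp (torch.cuda.amp) - what --fp16 maps to on pytorch >= 1.6.
scaler = torch.cuda.amp.GradScaler()
for x in data:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():      # forward pass runs in mixed precision
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()        # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)
    scaler.update()

# Path 2: apex amp at opt_level O1 - the pytorch-1.5 + apex path.
# Left commented out so the snippet runs without apex installed.
# from apex import amp
# model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
# for x in data:
#     optimizer.zero_grad()
#     loss = model(x).pow(2).mean()
#     with amp.scale_loss(loss, optimizer) as scaled_loss:
#         scaled_loss.backward()
#     optimizer.step()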

Here is the table of the batch sizes that fit into an 8GB rtx-1070 (a bigger bs leads to an instant OOM):

 bs | version
----+----------
 12 | pt15
 20 | pt15+fp16
 12 | pt16
  1 | pt16+fp16

If you’d like to reproduce the problem, here are the full steps:

# prep library
git clone https://github.com/huggingface/transformers
cd transformers
pip install -e .[dev]
pip install -r examples/requirements.txt
cd examples/seq2seq

# prep data
wget https://cdn-datasets.huggingface.co/summarization/cnn_dm_v2.tgz
tar -xzvf cnn_dm_v2.tgz  # empty lines removed
mv cnn_cln cnn_dm

# run
export BS=12; 
rm -rf distilbart-cnn-12-6
python finetune.py --learning_rate=3e-5 --gpus 1 \
--do_train --do_predict --val_check_interval 0.25 --n_val 500 --num_train_epochs 2 --freeze_encoder \
--freeze_embeds --data_dir cnn_dm --max_target_length 142 --val_max_target_length=142 \
--train_batch_size=$BS --eval_batch_size=$BS --gradient_accumulation_steps 1 \
--model_name_or_path sshleifer/student_cnn_12_6 --tokenizer_name facebook/bart-large \
--warmup_steps 500 --output_dir distilbart-cnn-12-6 

This issue is to track the problem and hopefully find a solution.

@sshleifer

Issue Analytics

  • State: closed
  • Created: 3 years ago
  • Comments: 35 (28 by maintainers)

Top GitHub Comments

5 reactions
stas00 commented, Dec 15, 2020

A fix has been applied to pytorch-nightly (https://github.com/pytorch/pytorch/pull/48696, which fixes https://github.com/pytorch/pytorch/issues/48049), and I verified with pytorch-nightly that this issue no longer leaks memory under native amp.

Please note that this change is only going to be available in pytorch-1.8, so until then native amp and transformers aren’t always going to play well together. The workaround in the meantime is to use apex.

edit: good news - it seems that pytorch-1.7.1 will have this fix too! https://github.com/pytorch/pytorch/issues/48049#issuecomment-742790722
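
For anyone wanting a quick check of whether a given pytorch build still misbehaves, here is a minimal sketch that compares peak GPU memory with and without native amp via torch.cuda.max_memory_allocated. It uses a hypothetical toy model, so it may or may not trigger the original leak - the definitive check is re-running the finetune.py command above - but it shows how to compare builds and settings:

# Hypothetical toy benchmark - compares peak GPU memory for fp32 vs native amp.
# On a build with the fix, the amp run should not use vastly more memory.
import torch
import torch.nn as nn

def peak_mem_mb(use_amp: bool, steps: int = 50) -> float:
    model = nn.Sequential(nn.Linear(2048, 2048), nn.ReLU(), nn.Linear(2048, 2048)).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()   # start peak tracking from a clean slate
    for _ in range(steps):
        x = torch.randn(12, 2048, device="cuda")
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            loss = model(x).pow(2).mean()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    return torch.cuda.max_memory_allocated() / 2**20

print(f"fp32       peak: {peak_mem_mb(False):.0f} MB")
print(f"native amp peak: {peak_mem_mb(True):.0f} MB")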

2 reactions
sgugger commented, Nov 17, 2020

Great, thanks a lot for your thorough investigation @stas00!
