Issue with loading a pretrained model using DeepSpeed ZeRO Stage 3
System Info
- `transformers` version: 4.19.0.dev0
- Platform: Linux-5.4.0-90-generic-x86_64-with-glibc2.29
- Python version: 3.8.10
- Huggingface_hub version: 0.5.1
- PyTorch version (GPU?): 1.12.0.dev20220505+cu113 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: yes (deepspeed zero stage-3)
Who can help?
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the examples folder (such as GLUE/SQuAD, …)
- My own task or dataset (give details below)
Reproduction
Steps to reproduce the behaviour:
- Official run_glue.py script
- The ZeRO Stage-3 config zero3_config.json below (a sketch of the ZeRO-3 loading path it drives is given after this list):
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto",
            "torch_adam": true,
            "adam_w_mode": true
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto",
            "total_num_steps": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}
- Bash script to run the finetuning of bert-base-uncased on the MRPC dataset using ZeRO Stage-3:
#!/bin/bash
time torchrun --nproc_per_node=2 run_glue.py \
--task_name "mrpc" \
--max_seq_len 128 \
--model_name_or_path "bert-base-uncased" \
--output_dir "./glue/mrpc_deepspeed_stage3_trainer" \
--overwrite_output_dir \
--do_train \
--evaluation_strategy "epoch" \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 16 \
--gradient_accumulation_steps 1 \
--learning_rate 2e-5 \
--weight_decay 0.0 \
--max_grad_norm 1.0 \
--num_train_epochs 3 \
--lr_scheduler_type "linear" \
--warmup_steps 50 \
--logging_steps 100 \
--fp16 \
--fp16_full_eval \
--optim "adamw_torch" \
--report_to "wandb" \
--deepspeed "zero3_config.json"
- Relevant output snippets. The first shows the unexpected behaviour wherein the model is not properly initialized with the pretrained weights. The second shows the eval metrics at random-chance performance.
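For context on the loading path involved: with the Trainer, passing the config above via --deepspeed is all that is needed, and under ZeRO Stage-3 from_pretrained builds the model inside deepspeed.zero.Init so the weights are sharded across ranks while they are being loaded. Below is a minimal sketch of the non-Trainer equivalent of that loading path, for orientation only; it assumes transformers ~4.19 (where HfDeepSpeedConfig lives in transformers.deepspeed), uses a small concrete config instead of the "auto" placeholders (only the Trainer resolves those), and is expected to run under a distributed launcher such as torchrun.

import deepspeed
from transformers import AutoModelForSequenceClassification
from transformers.deepspeed import HfDeepSpeedConfig

# Concrete stand-in for zero3_config.json: outside the Trainer the "auto"
# placeholders are not resolved, so real values are needed here.
ds_config = {
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 3},
    "train_micro_batch_size_per_gpu": 16,
    "gradient_accumulation_steps": 1,
}

# Must be created (and kept referenced) BEFORE from_pretrained so that the
# weights are materialized through deepspeed.zero.Init and sharded across
# ranks, instead of every rank loading a full dense copy.
dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive

model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)

# Inference-style engine init; a training setup would also configure an
# optimizer (either here or in the DeepSpeed config).
engine = deepspeed.initialize(model=model, config_params=ds_config)[0]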
Expected behavior
The model should be properly initialized with the pretrained weights when using DeepSpeed ZeRO Stage-3; that would resolve the poor model performance being observed.
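A quick way to make the reported behaviour visible without a full training run is to compare one ZeRO-3 sharded parameter against a plain CPU load of the same checkpoint. This is only a sketch: check_pretrained_weights is a hypothetical helper, and model is assumed to be the underlying HF model (e.g. trainer.model) after DeepSpeed has been initialized.

import deepspeed
import torch
import torch.distributed as dist
from transformers import AutoModelForSequenceClassification

def check_pretrained_weights(model, name="bert.embeddings.word_embeddings.weight"):
    # Reference copy loaded without DeepSpeed: full fp32 weights on CPU.
    ref = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=2
    ).state_dict()[name].float()

    param = dict(model.named_parameters())[name]
    # Under ZeRO-3 each rank only holds a shard of the parameter;
    # GatheredParameters temporarily reassembles the full tensor so it
    # can be compared against the reference.
    with deepspeed.zero.GatheredParameters([param]):
        same = torch.allclose(param.data.float().cpu(), ref, atol=1e-3)
    if dist.get_rank() == 0:
        print(f"{name} matches pretrained checkpoint: {same}")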
Top GitHub Comments
Hello @stas00, yes, the above PR solves this issue. Thank you 😄. Below are the plots from finetuning microsoft/deberta-v2-xlarge-mnli (the pretrained model has 3 labels) on the MRPC dataset (this task has 2 labels).

Thank you, @pacman100.

Please try this PR: https://github.com/huggingface/transformers/pull/17373
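The actual fix for this path under ZeRO Stage-3 is in the PR linked above. Purely as background, and not taken from the thread, this is the usual way to load a checkpoint whose classification head has a different number of labels than the target task:

from transformers import AutoModelForSequenceClassification

# The MNLI checkpoint ships a 3-label classification head; MRPC has 2 labels.
# ignore_mismatched_sizes drops the incompatible head weights and
# re-initializes a fresh 2-label classifier, while all other weights are
# still loaded from the checkpoint.
model = AutoModelForSequenceClassification.from_pretrained(
    "microsoft/deberta-v2-xlarge-mnli",
    num_labels=2,
    ignore_mismatched_sizes=True,
)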