[BUG] DeepSpeed zero_to_fp32.py script ignores some layers while creating FP32 checkpoints from DS ZeRO checkpoints.
Problem: Converting DeepSpeed ZeRO checkpoints to a PyTorch state_dict with the zero_to_fp32.py script leaves one layer out of the generated state dict. I am training a GPT2-like model, and the lm_head (linear layer) of the model is not being included in the converted checkpoint for some reason.
Description
I’m trying to train a GPT2-like model using DeepSpeed (and some code from huggingface/transformers). The model looks something like this:
CustomGPT2LMHeadModel(
(transformer): GPT2Model(
(wte): Embedding(50257, 1024)
(wpe): Embedding(1024, 1024)
(drop): Dropout(p=0.1, inplace=False)
(h): ModuleList(
(0): GPT2Block(
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(attn): GPT2Attention(
(c_attn): Conv1D()
(c_proj): Conv1D()
(attn_dropout): Dropout(p=0.1, inplace=False)
(resid_dropout): Dropout(p=0.1, inplace=False)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): GPT2MLP(
(c_fc): Conv1D()
(c_proj): Conv1D()
(dropout): Dropout(p=0.1, inplace=False)
)
)
(1): GPT2Block(
(ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(attn): GPT2Attention(
(c_attn): Conv1D()
(c_proj): Conv1D()
(attn_dropout): Dropout(p=0.1, inplace=False)
(resid_dropout): Dropout(p=0.1, inplace=False)
)
(ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
(mlp): GPT2MLP(
(c_fc): Conv1D()
(c_proj): Conv1D()
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
(ln_f): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
)
(lm_head): Linear(in_features=1024, out_features=50257, bias=False)
)
During training, every 0.2 epochs I save the model by calling model_engine.save_checkpoint(savedir) on all ranks, which is exactly what the DeepSpeed documentation tells users to do (a rough sketch of the save call follows the directory tree below). The generated checkpoint directory looks like this:
├── global_step12
│ ├── mp_rank_00_model_states.pt
│ ├── zero_pp_rank_0_mp_rank_00_optim_states.pt
│ ├── zero_pp_rank_1_mp_rank_00_optim_states.pt
│ ├── zero_pp_rank_2_mp_rank_00_optim_states.pt
│ └── zero_pp_rank_3_mp_rank_00_optim_states.pt
├── latest
└── zero_to_fp32.py
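For context, the save call inside the training loop looks roughly like this (a sketch; step counters and variable names are placeholders, not the actual training script):

# executed on every rank, not just rank 0 (sketch)
if steps_since_last_save >= save_every_n_steps:
    # each rank writes its ZeRO optimizer shard; rank 0 also writes
    # mp_rank_00_model_states.pt with the module weights
    model_engine.save_checkpoint(savedir)
    steps_since_last_save = 0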
After the training is done, I try running:
python zero_to_fp32.py . pytorch_model.bin
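As a side note, the same consolidation can also be done programmatically; a minimal sketch, assuming the installed DeepSpeed version (0.6.1 here) exposes this helper in deepspeed.utils.zero_to_fp32:

from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint

# "." is the checkpoint root (the directory containing `latest` and global_step12)
state_dict = get_fp32_state_dict_from_zero_checkpoint(".")
print(sorted(state_dict.keys()))  # same keys as the .bin produced by the script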
Next, when I load this checkpoint in PyTorch, the weights for the lm_head layer are not there:
>>> import torch
>>> state_dict = torch.load("pytorch_model.bin")
>>> state_dict.keys()
odict_keys(['transformer.h.0.attn.bias', 'transformer.h.0.attn.masked_bias', 'transformer.h.1.attn.bias', 'transformer.h.1.attn.masked_bias', 'transformer.wte.weight', 'transformer.wpe.weight', 'transformer.h.0.ln_1.weight', 'transformer.h.0.ln_1.bias', 'transformer.h.0.attn.c_attn.weight', 'transformer.h.0.attn.c_attn.bias', 'transformer.h.0.attn.c_proj.weight', 'transformer.h.0.attn.c_proj.bias', 'transformer.h.0.ln_2.weight', 'transformer.h.0.ln_2.bias', 'transformer.h.0.mlp.c_fc.weight', 'transformer.h.0.mlp.c_fc.bias', 'transformer.h.0.mlp.c_proj.weight', 'transformer.h.0.mlp.c_proj.bias', 'transformer.h.1.ln_1.weight', 'transformer.h.1.ln_1.bias', 'transformer.h.1.attn.c_attn.weight', 'transformer.h.1.attn.c_attn.bias', 'transformer.h.1.attn.c_proj.weight', 'transformer.h.1.attn.c_proj.bias', 'transformer.h.1.ln_2.weight', 'transformer.h.1.ln_2.bias', 'transformer.h.1.mlp.c_fc.weight', 'transformer.h.1.mlp.c_fc.bias', 'transformer.h.1.mlp.c_proj.weight', 'transformer.h.1.mlp.c_proj.bias', 'transformer.ln_f.weight', 'transformer.ln_f.bias'])
From the above snippet, we can see that the lm_head layer is missing from the keys of the state_dict generated by zero_to_fp32.py.
More interesting info
Now, when a job runs on N GPUs, the checkpoint directory contains N+1 files (this example was run on 4 GPUs):
mp_rank_00_model_states.pt
zero_pp_rank_0_mp_rank_00_optim_states.pt
zero_pp_rank_1_mp_rank_00_optim_states.pt
zero_pp_rank_2_mp_rank_00_optim_states.pt
zero_pp_rank_3_mp_rank_00_optim_states.pt
Loading the first file and inspecting its contents shows that the lm_head weights are present in the DeepSpeed checkpoint, which suggests that something is wrong with the zero_to_fp32.py script.
>>> import torch
>>> model_states = torch.load("mp_rank_00_model_states.pt")
>>> "lm_head.weight" in model_states['module'].keys()
True
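A quick way to confirm exactly which parameters are dropped during conversion is to diff the two key sets (a sketch; adjust the paths to wherever the two files live):

import torch

ds_keys = set(torch.load("global_step12/mp_rank_00_model_states.pt", map_location="cpu")["module"].keys())
fp32_keys = set(torch.load("pytorch_model.bin", map_location="cpu").keys())

# parameters present in the DeepSpeed checkpoint but missing after conversion
print(ds_keys - fp32_keys)  # here: {'lm_head.weight'}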
Reproducing this bug
- A minimal working example of this problem can be found in this GitHub Gist.
- Download the .py file and the .json file from the gist and install the Python dependencies from dependencies.txt.
- Run deepspeed minimal_reproducible_ds.py.
- Checkpoints are saved within a results_* directory. Navigate into that directory and run the zero_to_fp32.py script.
- The generated state_dict/.bin file will not contain the weights of the lm_head.
Expected Behavior
The state_dict created by DeepSpeed's zero_to_fp32.py script should contain the weights of all layers; no layer should be omitted from the state_dict.
ds_report output
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
[WARNING] please install triton==1.0.0 if you want to use sparse attention
sparse_attn ............ [NO] ....... [NO]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
async_io ............... [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/gandiva/rohitd/.venv/lib/python3.6/site-packages/torch']
torch version .................... 1.10.2+cu102
torch cuda version ............... 10.2
torch hip version ................ None
nvcc version ..................... 11.6
deepspeed install path ........... ['/home/gandiva/rohitd/.venv/lib/python3.6/site-packages/deepspeed']
deepspeed info ................... 0.6.1, unknown, unknown
deepspeed wheel compiled w. ...... torch 1.10, cuda 10.2, hip 0.0
System info
- OS: Ubuntu 18.04.6 LTS
- 4 * 16GB v100 GPUs
- Python 3.6.9
Additional information
I understand that there is already a Hugging Face + DeepSpeed integration if you use Hugging Face's Trainer class. However, all of Hugging Face's modules (including the GPT2LMHead class used here) are just subclasses of torch.nn.Module, which, I assume, means they should also work in this code sample, which uses the DeepSpeed API directly.
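For completeness, the gist wires the plain torch.nn.Module into DeepSpeed roughly like this (a sketch with placeholder names such as ds_config.json; see the gist for the actual code):

import deepspeed

model = CustomGPT2LMHeadModel(model_config)  # an ordinary torch.nn.Module subclass

# deepspeed.initialize wraps the module in an engine that owns the
# ZeRO-partitioned optimizer state and exposes save_checkpoint()
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",  # placeholder; the gist ships its own JSON config
)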
Top GitHub Comments
Hi @rohitdwivedula @tjruwase, I ran into a similar problem and I think I have found where it comes from.
This kind of problem can appear when parameter sharing is used.
I think the missing lm_head.weight shares its parameter with the word embedding weight ('transformer.wte.weight' here), which is a commonly used trick in NLP (weight tying). So there is actually only one real parameter, held by 'transformer.wte.weight', and lm_head.weight merely holds a reference to it. When converting a checkpoint, DeepSpeed recovers the fp32 parameters from the optimizer state, which contains only the true parameters, so the tied references and the parameters that only hold references are not recovered.
The simplest workaround for now is to load the converted checkpoint and manually set the missing parameters back from their tied counterparts, e.g.
ckpt['state_dict']['lm_head.weight'] = ckpt['state_dict']['transformer.wte.weight']
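Applied to this issue's setup, where torch.load on the converted .bin returns the state dict directly rather than a wrapper with a 'state_dict' key, the workaround might look like this (a sketch; that lm_head.weight is tied to transformer.wte.weight is taken from the analysis above, and the model construction line is a placeholder):

import torch

state_dict = torch.load("pytorch_model.bin", map_location="cpu")

# restore the tied output head from the input embedding weight
if "lm_head.weight" not in state_dict:
    state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

# model = CustomGPT2LMHeadModel(model_config)  # placeholder: build the model as in the gist
# model.load_state_dict(state_dict)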
@richarddwang, thanks for sharing your great analysis and workaround. It seems we need to extend zero checkpointing to better handle parameter sharing.