[BUG] fp16 weights in model_states.pt do not match fp32 weights extracted using zero_to_fp32.py
Hi, after training, the checkpoint folder contains one `model_states.pt` file and multiple `optim_states.pt` files. One can use `zero_to_fp32.py` to extract fp32 weights from the `optim_states.pt` files. As I checked when using ZeRO Stage 2, the fp16 weights (the weights under the key `module`) in `model_states.pt` do not match the fp32 weights extracted using `zero_to_fp32.py`; this is not a precision difference, the values are very different. @stas00 @tjruwase
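For reference, the mismatch can be observed with a check along these lines; this is a minimal sketch, and the exact checkpoint file names depend on the setup (only the `module` key is taken from the description above):

```python
# Minimal sketch of the comparison described above. The file names below are
# assumptions for this setup; "module" is the key holding the fp16 weights.
import torch

fp16_sd = torch.load("mp_rank_00_model_states.pt", map_location="cpu")["module"]
fp32_sd = torch.load("pytorch_model.bin", map_location="cpu")  # output of zero_to_fp32.py

for name, fp16_w in fp16_sd.items():
    fp32_w = fp32_sd[name].float()
    diff = (fp16_w.float() - fp32_w).abs().max().item()
    # with a correct extraction the difference should be at most fp16 rounding error
    print(f"{name}: max abs diff = {diff:.6f}")
```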
Update of the bug:
I finally found that the script `zero_to_fp32.py` only works if the optimizer is constructed with a single param group. If there are two or more param groups, the script either raises `ValueError: consumed 66 numels out of 71 - something is wrong` or the extracted fp32 weights do not match the fp16 weights. Digging further into the DeepSpeed Stage 2 code, I found that the partitioning of parameters is done per param group, i.e. each rank actually gets one slice of every param group. The script `zero_to_fp32.py` assumes there is only one param group, which is where the bug comes from. To fix it, the script has to take param groups into account when reconstructing the fp32 parameters.
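To make the fix concrete, here is a minimal sketch (not the actual `zero_to_fp32.py` code) of a group-aware reconstruction; `fp32_flat_groups` and `param_shapes_per_group` are hypothetical names for the per-rank flat fp32 partitions and the per-group parameter shapes:

```python
# Sketch of group-aware reconstruction, assuming:
#   fp32_flat_groups[rank][group_idx]  - flat fp32 partition held by `rank` for that group
#   param_shapes_per_group[group_idx]  - OrderedDict of param name -> shape for that group
import torch

def reconstruct_fp32_state_dict(fp32_flat_groups, param_shapes_per_group):
    state_dict = {}
    world_size = len(fp32_flat_groups)
    for group_idx, shapes in enumerate(param_shapes_per_group):
        # concatenate this group's partitions from all ranks, then unflatten
        flat = torch.cat([fp32_flat_groups[rank][group_idx] for rank in range(world_size)])
        offset = 0
        for name, shape in shapes.items():
            numel = 1
            for dim in shape:
                numel *= dim
            state_dict[name] = flat[offset:offset + numel].view(shape)
            offset += numel
        # anything left over in `flat` is per-group alignment padding and is discarded
    return state_dict
```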
Here is a reproduction of the bug:
```python
import os

import torch
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import DeepSpeedPlugin
from deepspeed.ops.adam import FusedAdam


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        # two param groups (decay / no-decay): this is what triggers the bug in zero_to_fp32.py
        no_decay = ["bias"]
        params_decay = [
            p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)
        ]
        params_nodecay = [
            p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)
        ]
        optim_groups = [
            {"params": params_decay, "weight_decay": 0.01},
            {"params": params_nodecay, "weight_decay": 0.0},
        ]
        return torch.optim.SGD(optim_groups, lr=0.1)
        # return torch.optim.SGD(self.layer.parameters(), lr=0.1)  # single param group: conversion works


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=4)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=4)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=4)

    model = BoringModel()
    checkpoint_callback = ModelCheckpoint(
        dirpath='tests/checkpoints',
        save_last=True,
        every_n_train_steps=5,
    )
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        gpus=-1,
        limit_train_batches=1,
        limit_val_batches=1,
        num_sanity_val_steps=0,
        precision=16,
        accelerator='ddp',
        max_epochs=10,
        plugins=[DeepSpeedPlugin(stage=2)],
        weights_summary=None,
        callbacks=[checkpoint_callback],
    )
    trainer.fit(model)
    trainer.test(model, test_dataloaders=test_data)


if __name__ == '__main__':
    run()
```
After the run, using `zero_to_fp32.py` in the checkpoint folder to extract the fp32 weights leads to this error: `ValueError: consumed 66 numels out of 71 - something is wrong`.
Top GitHub Comments
It should be noted that `single_partition_of_fp32_groups` is a list of flat tensors, where each element corresponds to one param group, rather than a single flat tensor, and the padding happens for each param group individually. So I think it makes sense to treat each param group separately. Here is my temporary solution: the tricky thing is that one has to infer the `param_shapes` for each param group, and I don't think this is available without knowing what the param groups are. In my case, I divided the `param_shapes` dictionary based on the param groups of my training code (of course following the same param order), and then reconstructed the fp32 weights for each param group following your original conversion code. Hope that helps.
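A rough sketch of this workaround (the helper name and inputs below are hypothetical) is to split the checkpoint's `param_shapes` dictionary per param group and then apply the existing single-group conversion to each piece:

```python
# Hypothetical helper illustrating the workaround: split the checkpoint's
# param_shapes OrderedDict into one dict per optimizer param group, following
# the same parameter order used in configure_optimizers().
from collections import OrderedDict

def split_param_shapes(param_shapes, group_name_lists):
    """param_shapes: OrderedDict mapping param name -> shape (from the checkpoint).
    group_name_lists: one list of param names per param group, in optimizer order."""
    return [OrderedDict((n, param_shapes[n]) for n in names) for names in group_name_lists]

# For the BoringModel above: group 0 holds the decayed weight, group 1 the bias.
# per_group_shapes = split_param_shapes(param_shapes, [["layer.weight"], ["layer.bias"]])
```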
@stas00 Thanks for the reply. I have made detailed modifications to the bug report above. Please see the updated version.