
Question: Can't load a ZeRO-3 model with Engine.load_checkpoint()

See original GitHub issue

I used Engine.save_checkpoint() to save my ZeRO-3 model_engine, but when I load it back with Engine.load_checkpoint() I get the runtime error below:

Traceback (most recent call last):
  File "train.py", line 329, in <module>
    path, state = model_engine.load_checkpoint("Model/tmp", tag="ckpt")
  File "/home/huangbz/.conda/envs/NLP/lib/python3.6/site-packages/deepspeed/runtime/engine.py", line 1919, in load_checkpoint
    load_module_only=load_module_only)
  File "/home/huangbz/.conda/envs/NLP/lib/python3.6/site-packages/deepspeed/runtime/engine.py", line 1969, in _load_checkpoint
    strict=load_module_strict)
  File "/home/huangbz/.conda/envs/NLP/lib/python3.6/site-packages/deepspeed/runtime/engine.py", line 1819, in load_module_state_dict
    self.module.load_state_dict(state_dict, strict=strict)
  File "/home/huangbz/.conda/envs/NLP/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1224, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for BartForConditionalGeneration:
        size mismatch for model.encoder.embed_positions.weight: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([1026, 768]).
        ......
        size mismatch for model.decoder.layernorm_embedding.bias: copying a param with shape torch.Size([1]) from checkpoint, the shape in current model is torch.Size([768]).

I'm using DeepSpeed ZeRO-3 to train BART (the Hugging Face Transformers implementation) on 4 GPUs, launched with deepspeed --num_gpus=4 train.py --deepspeed --deepspeed_config config/ds_config.json. Here is my code (to keep the question short, I've skipped the training loop and only test the save & load round trip):

    import argparse
    import os

    import deepspeed
    from transformers import BartForConditionalGeneration

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", default=-1, type=int,
                        help="local_rank for distributed training on GPUs")
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
    args.local_rank = int(os.environ['LOCAL_RANK'])

    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
            [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
            0.01
    }, {
        'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
            0.0
    }]

    model_engine, optimizer, _, scheduler = deepspeed.initialize(args=args, model=model,
                                                                 model_parameters=optimizer_grouped_parameters)
    # Save, then immediately try to reload in the same process -- this is
    # where the RuntimeError above is raised.
    model_engine.save_checkpoint("Model/tmp", tag="ckpt")
    path, state = model_engine.load_checkpoint("Model/tmp", tag="ckpt")
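The size mismatches in the traceback come from PyTorch's strict state_dict loading: under ZeRO-3, parameters that are not currently gathered on a rank are held as size-1 placeholders, so the comparison sees torch.Size([1]) against the full shape. A minimal sketch of that check (plain Python, no DeepSpeed; the parameter names and shapes are taken from the traceback above):

```python
# Sketch of the strict shape check behind the RuntimeError above.
# The checkpoint side holds (1,) placeholders (the ZeRO-3 partitioned view),
# while the live module expects the full shapes.
def check_state_dict_shapes(model_shapes, ckpt_shapes):
    """Return load_state_dict-style size-mismatch messages."""
    errors = []
    for name, model_shape in model_shapes.items():
        ckpt_shape = ckpt_shapes.get(name)
        if ckpt_shape is not None and ckpt_shape != model_shape:
            errors.append(
                f"size mismatch for {name}: copying a param with shape "
                f"torch.Size({list(ckpt_shape)}) from checkpoint, the shape "
                f"in current model is torch.Size({list(model_shape)})."
            )
    return errors

model_shapes = {
    "model.encoder.embed_positions.weight": (1026, 768),
    "model.decoder.layernorm_embedding.bias": (768,),
}
# Under ZeRO-3, un-gathered parameters appear as size-1 placeholders.
ckpt_shapes = {name: (1,) for name in model_shapes}

print("\n\t".join(check_state_dict_shapes(model_shapes, ckpt_shapes)))
```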

And here is my ds_config.json:

{
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 3e-05,
            "weight_decay": 0.01
        }
    },

    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 3e-05,
            "warmup_num_steps": 400,
            "total_num_steps": 9000
        }
    },

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": 5e8,
        "stage3_prefetch_bucket_size": 5e8,
        "stage3_param_persistence_threshold": 1e6,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_fp16_weights_on_model_save": true
    },
    "gradient_clipping": 0.1,
    "train_micro_batch_size_per_gpu": 2,
    "train_batch_size": 32,
    "wall_clock_breakdown": false
}

I'm new to DeepSpeed and not familiar with every detail of ZeRO-3, so please help me solve this problem. Thanks a lot!

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Comments: 13 (13 by maintainers)

Top GitHub Comments

1 reaction
skpig commented, Oct 7, 2021

No, it doesn’t. The traceback is the same as before with model_parameters=model.parameters()

0 reactions
stas00 commented, Oct 17, 2021

I totally agree, @skpig, it indeed should be documented.

If you’d like you could make a PR adding a note explaining this limitation somewhere in the docstring here: https://github.com/microsoft/DeepSpeed/blob/1fc74cb9c81668b5ff0046446f8004d4cf8dc2d5/deepspeed/runtime/engine.py#L1997
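For later readers: the limitation behind this thread (also stated in the DeepSpeed Model Checkpointing docs) is that under ZeRO-3 you cannot call engine.load_checkpoint() immediately after engine.save_checkpoint() in the same process, because the gather-on-save step leaves engine.module holding partitioned placeholders. A pseudocode sketch of the usual workarounds; the zero_to_fp32.py helper name and its CLI are assumptions about the DeepSpeed version in use, so verify against your checkpoint folder:

```
# Run 1: save only. save_checkpoint() writes the sharded ZeRO-3 state
# (and, in recent DeepSpeed versions, a zero_to_fp32.py helper) to "Model/tmp".
model_engine.save_checkpoint("Model/tmp", tag="ckpt")

# Run 2 (a fresh process): resume normally. load_checkpoint() works when the
# engine has just been built by deepspeed.initialize(), i.e. not right after
# a save in the same process.
model_engine, _, _, _ = deepspeed.initialize(...)
path, state = model_engine.load_checkpoint("Model/tmp", tag="ckpt")

# Alternative: consolidate the shards into one fp32 state_dict offline,
# then load it with plain PyTorch:
#   $ python Model/tmp/zero_to_fp32.py Model/tmp pytorch_model.bin
#   model.load_state_dict(torch.load("pytorch_model.bin"))
```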


Top Results From Across the Web

[Deepspeed ZeRO-3] Broken model save on fresh ... - GitHub
The problem with DeepSpeed is that it doesn't currently have a way to save a fp32 checkpoint that can be loaded normally and...

Model Checkpointing — DeepSpeed 0.8.0 documentation
Important: under ZeRO3, one cannot load checkpoint with engine.load_checkpoint() right after engine.save_checkpoint(). It is because engine.module is ...

DeepSpeed Integration - Hugging Face
DeepSpeed ZeRO-3 can be used for inference as well, since it allows huge models to be loaded on multiple GPUs, which won't be...

Unable to load pre-trained model checkpoint with TensorFlow ...
Try changing the fine_tune_checkpoint path in the config file to something like ...

DeepSpeed ZeRO-3 Offload
Overview of ZeRO family of technology; ZeRO-3 Offload; Unprecedented model scale; Ease of supporting very large models; Excellent training ...
