
Load optimizer status from checkpoint file.

See original GitHub issue

🚀 Feature

In incremental training, we need to load the optimizer state along with the weights and pass both to the trainer to continue training. But the optimizer seems to be missing after loading the module from a checkpoint file.

ckpt_path = checkpoint_callback.best_model_path
ckpt_model = MyModule.load_from_checkpoint(ckpt_path)

ckpt_model.optimizers()  # this is empty
trainer.fit(ckpt_model)  # this starts a fresh optimizer instead of reusing the optimizer state from the checkpoint
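
For what it's worth, the optimizer state does exist inside the checkpoint file; load_from_checkpoint simply does not restore it. A minimal sketch to confirm this by inspecting the raw checkpoint (assuming a standard Lightning checkpoint with a single optimizer):

import torch

# Lightning checkpoints keep optimizer state under "optimizer_states",
# but load_from_checkpoint() only restores the model weights.
ckpt = torch.load(ckpt_path, map_location="cpu")
print(ckpt.keys())  # includes "state_dict" and "optimizer_states"
print(ckpt["optimizer_states"][0].keys())  # "state", "param_groups", etc.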

Motivation

In our pipeline, we load the checkpoint on the CPU side (a Spark cluster) and send the model to the GPU side (a Ray cluster) to do the remote training job. This works with Keras because when I load a compiled Keras model, the optimizer state is part of it, so I can send the compiled model to the GPU side to incrementally train it.
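
For comparison, this is roughly the Keras behavior being described: a compiled model saved with include_optimizer=True (the default) carries its optimizer state through save/load, so training resumes where it left off. A minimal TensorFlow 2.x sketch with a toy model:

import numpy as np
import tensorflow as tf

# Tiny compiled model; optimizer state is created during fit().
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer="adam", loss="mse")
x, y = np.random.rand(32, 4), np.random.rand(32, 1)
model.fit(x, y, epochs=1, verbose=0)

# include_optimizer=True (the default) saves the optimizer state too.
model.save("model.h5", include_optimizer=True)

# The loaded model is still compiled and keeps its optimizer state.
restored = tf.keras.models.load_model("model.h5")
restored.fit(x, y, epochs=1, verbose=0)  # resumes with the restored optimizer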

Pitch

When loading the LightningModule from a checkpoint, the optimizer state should be part of it, so I can pass the module to the trainer to continue training:

# load the model checkpoint on the CPU side
ckpt_path = checkpoint_callback.best_model_path
ckpt_model = MyModule.load_from_checkpoint(ckpt_path, include_optimizer=True)
ckpt_model.optimizers()  # should show the optimizer loaded from the checkpoint

# send to the GPU cluster for training
# ckpt_model = deserialize(serialize(ckpt_model))

# start the training job on the GPU cluster
trainer.fit(ckpt_model)  # trains with the optimizer state from the checkpoint
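
Until something like the proposed include_optimizer flag exists, a possible workaround is to restore the optimizer state by hand on the CPU side. A sketch, assuming configure_optimizers returns a single optimizer and the standard "optimizer_states" checkpoint key:

import torch

# Load the module weights and the raw checkpoint dict separately.
ckpt_path = checkpoint_callback.best_model_path
ckpt_model = MyModule.load_from_checkpoint(ckpt_path)
ckpt = torch.load(ckpt_path, map_location="cpu")

# Rebuild the optimizer from the module and restore its saved state.
optimizer = ckpt_model.configure_optimizers()  # assumes a single optimizer is returned
optimizer.load_state_dict(ckpt["optimizer_states"][0])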

Alternatives

Additional context



cc @awaelchli @ananthsub @ninginthecloud @rohitgr7

Issue Analytics

  • State: closed
  • Created: 2 years ago
  • Reactions: 1
  • Comments: 5 (4 by maintainers)

Top GitHub Comments

1 reaction
jjenniferdai commented, Mar 9, 2022

documentation for restoring training state instead of just model weights: https://pytorch-lightning.readthedocs.io/en/latest/common/checkpointing.html#restoring-training-state

I’ll highlight/make this more explicit in upcoming docs updates

1 reaction
ananthsub commented, Mar 9, 2022

The trainer will load the optimizer states from the checkpoint if you pass the checkpoint path as an argument to trainer.fit.

Does that work for you? @jjenniferdai
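
In code, the suggestion above looks like this (a sketch; in Lightning 1.5+ the ckpt_path argument to trainer.fit restores the full training state, including optimizer states, not just the weights):

from pytorch_lightning import Trainer

# A fresh module; weights and optimizer state are restored from the checkpoint.
model = MyModule()
trainer = Trainer()
trainer.fit(model, ckpt_path=checkpoint_callback.best_model_path)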


Top Results From Across the Web

  • Saving and loading a general checkpoint in PyTorch: A common PyTorch convention is to save these checkpoints using the .tar file extension. To load the items, first initialize the model and... (sketched after this list)

  • Load optimizer status from checkpoint file. #12280 - GitHub: The trainer will load the optimizer states from the checkpoint if you pass the checkpoint path as an argument to trainer.fit. Does...

  • Training checkpoints | TensorFlow Core: Train and checkpoint the model. The following training loop creates an instance of the model and of an optimizer, then gathers them into...

  • How to load a checkpoint file in a pytorch model?: To load this checkpoint file, I check and see if the checkpoint file exists and then I load it as well as the...

  • Checkpointing Distributed Models and Optimizer States: The SageMaker model parallelism library provides checkpoint APIs to save and load distributed models and optimizer states.
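
The general-checkpoint convention from the first result above, sketched with a hypothetical toy model (model and optimizer state dicts saved together in one .tar file):

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters())

# Save both state dicts together; .tar is the common convention.
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}, "checkpoint.tar")

# To load, first initialize the model and optimizer, then restore both states.
checkpoint = torch.load("checkpoint.tar")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])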
