
load_checkpoint nuances

See original GitHub issue

I have a few questions about model checkpointing: https://deepspeed.readthedocs.io/en/latest/model-checkpointing.html

I’m trying to figure out how to best integrate deepspeed into that area.

  1. If we already have code that checkpoints the model/optim/scheduler, so in a simplified form we have the basics:
# simplified - each object goes to its own file under dir d
torch.save(self.optimizer.state_dict(), os.path.join(d, "optimizer.pt"))
torch.save(self.lr_scheduler.state_dict(), os.path.join(d, "scheduler.pt"))
torch.save(self.model.state_dict(), os.path.join(d, "model.pt"))

where self.model is the “client” model. And then the same for loading.

Now when I call deepspeed.DeepSpeedEngine.save_checkpoint I get 4 things saved: engine/model/optim/scheduler.

When it comes to loading it back, I do deepspeed.DeepSpeedEngine.load_checkpoint - do I need to somehow update our trainer’s self.scheduler and self.optimizer from that loaded object? I don’t see an API to do that.

Or would it be simpler not to delegate any saving to DS other than its own engine state, and to save and restore model/optim/scheduler ourselves (since we do that anyway when the trainer is not running under DeepSpeed)?

To exemplify with code:

We start with:

model, optimizer, _, lr_scheduler = deepspeed.initialize(...)
self.deepspeed = model        # the DeepSpeedEngine object
self.model = model.module     # the original "client" model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler

So the new saving code would be:

torch.save(self.optimizer.state_dict(), os.path.join(d, "optimizer.pt"))
torch.save(self.lr_scheduler.state_dict(), os.path.join(d, "scheduler.pt"))
torch.save(self.model.state_dict(), os.path.join(d, "model.pt"))
if self.deepspeed:
    self.deepspeed.save_checkpoint(d)
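
As an aside, if I’m reading https://deepspeed.readthedocs.io/en/latest/model-checkpointing.html right, save_checkpoint also accepts a client_state dict that gets stored inside the DS checkpoint, which might let us drop some of the separate files. A sketch, where "global_step" is just an illustrative key of ours, not something DS requires:

# hedged: client_state lets trainer-side state ride along in the DS checkpoint
if self.deepspeed:
    self.deepspeed.save_checkpoint(d, client_state={"global_step": self.global_step})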

and then on load again leave most of our code intact and just update the engine:

self.optimizer.load_state_dict(torch.load(os.path.join(model_path, "optimizer.pt")))
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
self.model = self.model.from_pretrained(model_path)
if self.deepspeed:
    self.deepspeed.load_checkpoint(model_path, load_optimizer_states=False, load_lr_scheduler_states=False)
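
For context, the docs describe load_checkpoint as returning a tuple rather than handing back new components, so I’d expect the call above to look like this when the return value is kept (my reading of the documented signature, worth confirming):

# hedged: per the docs, load_checkpoint returns (load_path, client_state),
# not new model/optimizer/scheduler objects
load_path, client_state = self.deepspeed.load_checkpoint(
    model_path, load_optimizer_states=False, load_lr_scheduler_states=False
)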

Am I wasting resources by saving/loading the separate components, since deepspeed will have to do it anyway? I ask because our code is spread around and we don’t always load all the components together: e.g. sched/optim are loaded separately, so we end up loading the model twice, because deepspeed doesn’t separate the components. That is, we can’t tell it not to load the model (but we can skip loading the sched/optim).

Alternatively, I could just do:

if self.deepspeed:
    self.deepspeed.load_checkpoint(model_path, load_optimizer_states=True, load_lr_scheduler_states=True)
else:
    self.optimizer.load_state_dict(torch.load(os.path.join(model_path, "optimizer.pt")))
    self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))
    self.model = self.model.from_pretrained(model_path)

and if this is done, do all the variables we assigned at the beginning from deepspeed.initialize, e.g. self.optimizer, get updated to the loaded-from-the-checkpoint values, or do we now somehow have to recreate all those variables?

# hypothetical - is there an API to re-fetch each component after loading?
model, optimizer, _, lr_scheduler = self.deepspeed.somehow_get_each_component_again()
self.deepspeed = model
self.model = model.module
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
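
From skimming the engine source, it looks to me like the engine keeps the wrapped components as attributes and load_checkpoint restores state into those same objects in place, in which case the handles taken at initialize() time would stay valid and no re-fetching would be needed. A sketch of what I mean (the attribute names are my reading of the source, please correct me):

# hedged: the engine appears to expose the wrapped components as attributes,
# and load_checkpoint seems to restore state into these same objects in place
model = self.deepspeed.module           # the client model
optimizer = self.deepspeed.optimizer    # presumably the same object as self.optimizer
lr_scheduler = self.deepspeed.lr_scheduler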

I hope my question is easy to understand.

If I were to ask it in a different way: what happens on deepspeed.load_checkpoint, where do things go, and what needs to be done besides loading the checkpoint? An example would have been very helpful.


  2. And one more question: we have code that checks whether the saved model dir contains saved optim/sched states:
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))

and loads them before training. How would you approach that for deepspeed? Which filesystem pattern should we match to detect that there is a saved DeepSpeed checkpoint that can be loaded?

I typically see a global_step0 folder. Is it always named the same, or perhaps you have a discovery function, so that we could do something like:

if deepspeed.has_checkpoint(path):
    deepspeed.load_checkpoint(path)

I suppose we could try/except too, but that’s not very clean if there is/could be an API to do that.
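
One heuristic that might work in the meantime, based on what I see save_checkpoint write to disk: there is a latest file next to the global_step* folders holding the most recent tag, so its presence could serve as the existence check. A sketch; the helper name is mine:

import os

def has_deepspeed_checkpoint(path):
    # hedged: save_checkpoint appears to write a "latest" file alongside
    # the global_step* tag folders; treat its presence as "checkpoint exists"
    return os.path.isfile(os.path.join(path, "latest"))

if has_deepspeed_checkpoint(model_path):
    self.deepspeed.load_checkpoint(model_path)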

And thinking more about it: since deepspeed.load_checkpoint will return (None, ?) if nothing is found at path, will this invalidate the existing deepspeed object?
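
In the meantime I suppose the return value can be checked defensively, e.g.:

# hedged: load_path is None when nothing was loaded; whether the engine
# is left untouched in that case is exactly what I'm asking above
load_path, client_state = self.deepspeed.load_checkpoint(model_path)
if load_path is None:
    print(f"no DeepSpeed checkpoint found under {model_path}")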

Thank you.

Issue Analytics

  • State: open
  • Created: 3 years ago
  • Comments: 13 (10 by maintainers)

Top GitHub Comments

2 reactions
tjruwase commented, Jan 13, 2021

This is so spot on. I completely agree. I feel the API would look similar to that if we could do it all over again. In fact, I would tweak your proposal ever so slightly:

# proposed API, not the current one:
deepspeed_engine = deepspeed.initialize(...)
assert deepspeed_engine.has_module() and deepspeed_engine.has_scheduler() and deepspeed_engine.has_optimizer()
model = deepspeed_engine.module
scheduler = deepspeed_engine.scheduler
optimizer = deepspeed_engine.optimizer

Do you think this could help with the confusion? We will discuss if we can move towards this.

1 reaction
tjruwase commented, Mar 18, 2021

@stas00, sorry, this needs to remain open for a bit longer in order to track it. I haven’t yet had a chance to work on this.

