load_checkpoint nuances
I have a few questions about model checkpointing: https://deepspeed.readthedocs.io/en/latest/model-checkpointing.html
I’m trying to figure out how best to integrate DeepSpeed into that area.
- If we already have code that checkpoints the model/optimizer/scheduler, so in a simplified way we have the basic:

```python
torch.save(self.optimizer.state_dict(), d)
torch.save(self.lr_scheduler.state_dict(), d)
torch.save(self.model.state_dict(), d)
```

where `self.model` is the “client” model. And then the same for loading.
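To make that save/load pattern concrete without depending on `torch` being installed, here is a generic sketch using `pickle` in place of `torch.save`/`torch.load`; the `Component` class is a made-up stand-in for anything exposing the `state_dict()`/`load_state_dict()` protocol:

```python
import os
import pickle


class Component:
    """Made-up stand-in for a model/optimizer/scheduler that exposes
    the state_dict() / load_state_dict() protocol."""

    def __init__(self, value=0):
        self.value = value

    def state_dict(self):
        return {"value": self.value}

    def load_state_dict(self, state):
        self.value = state["value"]


def save_all(components, d):
    # One file per component, mirroring optimizer.pt / scheduler.pt etc.
    for name, comp in components.items():
        with open(os.path.join(d, f"{name}.pt"), "wb") as f:
            pickle.dump(comp.state_dict(), f)


def load_all(components, d):
    # Restore each component's state in place from its own file.
    for name, comp in components.items():
        with open(os.path.join(d, f"{name}.pt"), "rb") as f:
            comp.load_state_dict(pickle.load(f))
```

The point of the per-component files is exactly the flexibility discussed below: each piece can be restored independently of the others.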
Now, when I call `deepspeed.DeepSpeedEngine.save_checkpoint`, I get four things saved: engine/model/optim/scheduler.
When it comes to loading it back, I call `deepspeed.DeepSpeedEngine.load_checkpoint`. Do I need to somehow update our trainer's `self.scheduler` and `self.optimizer` from that loaded object? I don't see an API to do that.
Or would it be simpler to not delegate any saving to DS other than its own engine state, and to save and restore the model/optim/scheduler ourselves (since we do that anyway when the trainer is not running under DeepSpeed)?
To exemplify with code, we start with:

```python
model, optimizer, _, lr_scheduler = deepspeed.initialize(...)
self.deepspeed = model        # the DeepSpeedEngine object
self.model = model.module    # the original "client" model
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
```
So the new saving code would be:

```python
torch.save(self.optimizer.state_dict(), d)
torch.save(self.lr_scheduler.state_dict(), d)
torch.save(self.model.state_dict(), d)
if self.deepspeed:
    self.deepspeed.save_checkpoint(d)
```
and then on load, again leave most of our code intact and just update the engine:

```python
self.optimizer.load_state_dict(torch.load(os.path.join(model_path, "optimizer.pt")...
self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")...
self.model = self.model.from_pretrained(model_path)
if self.deepspeed:
    self.deepspeed.load_checkpoint(model_path,
                                   load_optimizer_states=False,
                                   load_lr_scheduler_states=False)
```
Am I wasting resources saving/loading the separate components, since DeepSpeed will have to do it anyway? I'm asking because our code is spread around and we don't always load all components together: e.g. the scheduler/optimizer are loaded separately, so we end up loading the model twice, because DeepSpeed doesn't separate the components. That is, we can't tell it not to load the model (though we can skip loading the scheduler/optimizer).
Alternatively, I could just do:

```python
if self.deepspeed:
    self.deepspeed.load_checkpoint(model_path,
                                   load_optimizer_states=True,
                                   load_lr_scheduler_states=True)
else:
    self.optimizer.load_state_dict(torch.load(os.path.join(model_path, "optimizer.pt")...
    self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")...
    self.model = self.model.from_pretrained(model_path)
```
And if this is done, do all the variables we assigned at the beginning from `deepspeed.initialize` (e.g. `self.optimizer`) get updated to the loaded-from-the-checkpoint values, or do we now somehow have to recreate all of them?

```python
model, optimizer, _, lr_scheduler = self.deepspeed.somehow_get_each_component_again
self.deepspeed = model
self.model = model.module
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
```
I hope my question is easy to understand. To ask it a different way: what happens on `deepspeed.load_checkpoint`, where do things go, and what needs to be done besides loading the checkpoint? An example would have been very helpful.
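My understanding (to be confirmed, since the docs don't spell it out) is that `load_checkpoint` restores state into the objects the engine already holds and returns a `(load_path, client_state)` tuple, so the references obtained from `deepspeed.initialize` stay valid. Here is a toy mock of that contract; this is not real DeepSpeed code, and the checkpoint store is just an in-memory dict:

```python
class ToyEngine:
    """Toy mock of the contract I believe the engine follows.
    Not real DeepSpeed code: 'checkpoints' live in memory here."""

    def __init__(self, optimizer_state):
        self.optimizer_state = optimizer_state  # held by reference
        self._checkpoints = {}

    def save_checkpoint(self, tag):
        # Snapshot the current optimizer state under the given tag.
        self._checkpoints[tag] = dict(self.optimizer_state)

    def load_checkpoint(self, tag):
        if tag not in self._checkpoints:
            return None, None  # nothing found; engine left untouched
        # State is restored *in place*: no new optimizer object is
        # created, so any outside alias to optimizer_state sees the
        # loaded values without being reassigned.
        self.optimizer_state.clear()
        self.optimizer_state.update(self._checkpoints[tag])
        return tag, {}
```

If the real engine behaves like this, the trainer's `self.optimizer` etc. would not need to be re-fetched after loading.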
- And one more question: we have code that checks whether the saved model dir contains a saved optimizer/scheduler:

```python
and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
```

and loads them before training. How would you approach that for DeepSpeed: which filesystem pattern should we match to detect that there is a saved DeepSpeed checkpoint that can be loaded?
I typically see a `global_step0` folder. Is it always the same, or perhaps you have a discovery function, so that we could do something like:

```python
if deepspeed.has_checkpoint(path):
    deepspeed.load_checkpoint(path)
```

I suppose we could `try/except` too, but that's not very clean if there is (or could be) an API to do that.
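Lacking an official `has_checkpoint` API, one possible stopgap is a filesystem probe. This sketch assumes the layout I have seen `save_checkpoint` produce (a `latest` file holding the newest tag, plus per-tag folders such as `global_step0/`); the function name is hypothetical and the layout should be verified against your DeepSpeed version:

```python
import glob
import os


def has_deepspeed_checkpoint(path):
    """Heuristic probe for a DeepSpeed checkpoint directory.

    Assumed layout: a `latest` file holding the newest tag, and/or
    per-tag folders such as global_step0/. Verify against your
    DeepSpeed version before relying on this.
    """
    if not os.path.isdir(path):
        return False
    # The `latest` file is the most direct marker of a usable checkpoint.
    if os.path.isfile(os.path.join(path, "latest")):
        return True
    # Fall back to looking for any per-tag folder.
    return bool(glob.glob(os.path.join(path, "global_step*")))
```

This keeps the trainer's "should we resume?" check cheap and avoids the `try/except` route.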
And thinking about it more: since `deepspeed.load_checkpoint` will return `(None, ?)` if nothing is found at `path`, will this invalidate the existing DeepSpeed engine object?
Thank you.
Issue Analytics
- Created 3 years ago
- Comments: 13 (10 by maintainers)
Top GitHub Comments
This is so spot on. I completely agree. I feel the API would look similar to that if we could do it all over again. In fact, I would tweak your proposal ever so slightly.
Do you think this could help with the confusion? We will discuss if we can move towards this.
@stas00, sorry, this needs to remain open for a bit longer in order to track it. I haven't yet had a chance to work on this.