[BUG] TFTModel.load_from_checkpoint followed by .fit() raises an error.
Describe the bug
First, we train a TFTModel for 30 epochs. Then we want to do transfer learning by re-training that model: we load it from its last checkpoint and call .fit(..., epochs=additional_n_epochs), but an error occurs:
File "<string>", line 1, in <module>
File ".../python3.9/site-packages/darts/utils/torch.py", line 70, in decorator
return decorated(self, *args, **kwargs)
File ".../python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 771, in fit
return self.fit_from_dataset(
File ".../python3.9/site-packages/darts/utils/torch.py", line 70, in decorator
return decorated(self, *args, **kwargs)
File ".../python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 930, in fit_from_dataset
self._train(train_loader, val_loader)
File ".../python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 952, in _train
self.trainer.fit(
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
self._call_and_handle_interrupt(
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 723, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
results = self._run(model, ckpt_path=self.ckpt_path)
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1232, in _run
self._checkpoint_connector.restore_training_state()
File ".../python3.9/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 199, in restore_training_state
self.restore_loops()
File ".../python3.9/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 293, in restore_loops
raise MisconfigurationException(
pytorch_lightning.utilities.exceptions.MisconfigurationException: You restored a checkpoint with current_epoch=29, but you have set Trainer(max_epochs=5).
To Reproduce
additional_n_epochs = 5
my_model = TFTModel.load_from_checkpoint(mymodelname, work_dir=mymodeldir, best=False)
my_model.fit(..., epochs=additional_n_epochs)
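For completeness, here is a self-contained sketch of the scenario. The series, hyperparameters, model_name, and work_dir below are illustrative placeholders, not taken from the original report:

import numpy as np
from darts import TimeSeries
from darts.models import TFTModel

# Toy series; the real data is not shown in the report.
series = TimeSeries.from_values(np.random.randn(400).astype(np.float32))

# Initial training run: 30 epochs, with checkpoints saved under work_dir.
model = TFTModel(
    input_chunk_length=24,
    output_chunk_length=12,
    n_epochs=30,
    model_name="my_tft",       # placeholder for mymodelname
    work_dir="./darts_runs",   # placeholder for mymodeldir
    save_checkpoints=True,
    add_relative_index=True,   # avoids having to pass future_covariates
)
model.fit(series)

# Attempted continuation: reload the last checkpoint and train 5 more epochs.
additional_n_epochs = 5
my_model = TFTModel.load_from_checkpoint("my_tft", work_dir="./darts_runs", best=False)
# Fails: Lightning restores current_epoch=29 from the checkpoint, while the
# new trainer is configured with max_epochs=5.
my_model.fit(series, epochs=additional_n_epochs)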
Expected behavior
We expect training to resume from the epoch of the last checkpoint and continue until the total number of epochs reaches my_model.n_epochs + additional_n_epochs.
System (please complete the following information):
- Python version: 3.9
- darts version: 0.18.0
Hmm, it might be that we lost automatic support for that with the new PyTorch Lightning versions. Could you try this manual approach to see if it works?
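(The code snippet this comment refers to did not survive in this copy of the thread. Judging from the error message, one plausible manual approach is to pass the cumulative epoch count so that Lightning's max_epochs exceeds the restored current_epoch. This is a guess at the intent, not the maintainer's original snippet:)

# Hypothetical sketch of the manual approach (the original snippet is missing):
# make max_epochs the cumulative total rather than only the extra epochs,
# so that it exceeds the current_epoch=29 restored from the checkpoint.
total_epochs = my_model.n_epochs + additional_n_epochs  # e.g. 30 + 5 = 35
my_model.fit(series, epochs=total_epochs)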
@dennisbader, what's your opinion? Do you think we should rework the way we handle epochs?