
[BUG] TFTModel.load_from_checkpoint() followed by .fit() returns an error.


Describe the bug

First, we train a TFTModel for 30 epochs. We then aim to do transfer learning by re-training this model, loading it from its last checkpoint. When we execute .fit(..., epochs=additional_n_epochs), the following error occurs:

File "<string>", line 1, in <module>
File ".../python3.9/site-packages/darts/utils/torch.py", line 70, in decorator
  return decorated(self, *args, **kwargs)
File ".../python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 771, in fit
  return self.fit_from_dataset(
File ".../python3.9/site-packages/darts/utils/torch.py", line 70, in decorator
  return decorated(self, *args, **kwargs)
File ".../python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 930, in fit_from_dataset
  self._train(train_loader, val_loader)
File ".../python3.9/site-packages/darts/models/forecasting/torch_forecasting_model.py", line 952, in _train
  self.trainer.fit(
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
  self._call_and_handle_interrupt(
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 723, in _call_and_handle_interrupt
  return trainer_fn(*args, **kwargs)
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
  results = self._run(model, ckpt_path=self.ckpt_path)
File ".../python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1232, in _run
  self._checkpoint_connector.restore_training_state()
File ".../python3.9/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 199, in restore_training_state
  self.restore_loops()
File ".../python3.9/site-packages/pytorch_lightning/trainer/connectors/checkpoint_connector.py", line 293, in restore_loops
  raise MisconfigurationException(
pytorch_lightning.utilities.exceptions.MisconfigurationException: You restored a checkpoint with current_epoch=29, but you have set Trainer(max_epochs=5).
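For context (an editorial note, not part of the original report): PyTorch Lightning treats Trainer(max_epochs=...) as an absolute budget over the model's lifetime, so restoring a checkpoint saved at epoch 29 into a trainer capped at 5 epochs fails a sanity check before training even starts. Below is a minimal sketch of that check, paraphrasing the restore_loops code path cited in the traceback; the names are simplified and not the library's exact internals.

from pytorch_lightning.utilities.exceptions import MisconfigurationException

def restore_loops(current_epoch: int, max_epochs: int) -> None:
    # Paraphrase of checkpoint_connector.restore_loops: max_epochs is an
    # absolute cap, so a restored epoch counter must not exceed it.
    if max_epochs != -1 and current_epoch > max_epochs:
        raise MisconfigurationException(
            f"You restored a checkpoint with current_epoch={current_epoch}, "
            f"but you have set Trainer(max_epochs={max_epochs})."
        )

restore_loops(current_epoch=29, max_epochs=5)  # raises, as in the report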

To Reproduce

additional_n_epochs = 5
my_model = TFTModel.load_from_checkpoint(mymodelname, work_dir=mymodeldir, best=False)
my_model.fit(..., epochs=additional_n_epochs)

Expected behavior

We expect training to resume from the epoch of the last checkpoint and continue until the total number of epochs reaches my_model.n_epochs + additional_n_epochs.
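
One workaround consistent with this expectation (an editorial suggestion, not from the report, and assuming darts forwards fit(epochs=...) to the underlying Trainer(max_epochs=...), as the traceback suggests) would be to pass the absolute epoch total rather than only the increment:

# Hypothetical fix: pass the absolute epoch total, since PyTorch Lightning
# compares the restored current_epoch against max_epochs directly.
total_epochs = my_model.n_epochs + additional_n_epochs  # 30 + 5 = 35
my_model.fit(..., epochs=total_epochs)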

System (please complete the following information):

  • Python version: 3.9
  • darts version: 0.18.0

Issue Analytics

  • State: open
  • Created: a year ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

1 reaction
dennisbader commented, Jul 22, 2022

Hmm, it might be that we lost automatic support for that with the new PyTorch Lightning versions. Could you try this manual approach to see if it works?

import os
import pytorch_lightning as pl
from darts.models import TFTModel
from darts.models.forecasting.torch_forecasting_model import _get_checkpoint_folder, _get_checkpoint_fname

epochs = 30
additional_epochs = 5
model_name = mymodelname
work_dir = mymodeldir

# locate the last (not best) checkpoint file written by darts
ckpt_dir = _get_checkpoint_folder(work_dir, model_name)
file_name = _get_checkpoint_fname(work_dir, model_name, best=False)
ckpt_path = os.path.join(ckpt_dir, file_name)

my_model = TFTModel.load_from_checkpoint(model_name, work_dir=work_dir, best=False)
trainer_params = my_model.trainer_params

# instantiate a PyTorch Lightning trainer and tell it to resume from your last checkpoint
trainer = pl.Trainer(resume_from_checkpoint=ckpt_path, **trainer_params)

# note: epochs here is the absolute total (30 + 5), not the number of additional epochs
my_model.fit(..., epochs=epochs + additional_epochs, trainer=trainer)
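
A note for readers on newer library versions (editorial, not part of the comment above): Trainer(resume_from_checkpoint=...) was deprecated in PyTorch Lightning 1.5 in favor of passing ckpt_path directly to Trainer.fit(), and it was removed in later releases. In plain Lightning terms the resume would look like the sketch below; darts drives the trainer itself, so this is illustrative only, and lit_module and train_loader are placeholders for your LightningModule and DataLoader.

# Sketch for newer PyTorch Lightning versions: the checkpoint path moves
# from the Trainer constructor to Trainer.fit(ckpt_path=...).
trainer = pl.Trainer(max_epochs=epochs + additional_epochs)
trainer.fit(lit_module, train_dataloaders=train_loader, ckpt_path=ckpt_path)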
0 reactions
hrzn commented, Aug 23, 2022

@dennisbader, what’s your opinion: do you think we should rework the way we handle epochs?


Top Results From Across the Web

  • How to properly load checkpoint for testing? #924 - GitHub
    Am I doing something wrong here, causing the accuracy returned to be different when I load the checkpoints separately instead of doing .test()...
  • Temporal Fusion Transformer (TFT) — darts documentation
    This allows to use the TFTModel without having to pass future_covariates to fit() and train(). It gives a value to the position...
  • Demand forecasting with the Temporal Fusion Transformer
    After training, we can make predictions with predict(). The method allows very fine-grained control over what it returns so that, for example,...
  • darts [python]: Datasheet - Package Galaxy
    Description: A python library for easy manipulation and forecasting of time series. Installation: pip install darts. Last version: 0.22.0 (Download) Homepage...
