
Can't save model when using Rich Progress Bar (multiprocessing)

See original GitHub issue

Describe the bug

When using the RichProgressBar callback, I run into the following error when trying to call save_model:

TypeError: can't pickle _thread.RLock objects

Without the callback everything works as expected. Is there a workaround, e.g. with state_dicts, for now?

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
ipykernel_79835/4020755580.py in <module>
----> 1 model.save_model("basemodel.pth.tar")

site-packages/darts/models/forecasting/torch_forecasting_model.py in save_model(self, path)
   1311         # We save the whole object to keep track of everything
   1312         with open(path, "wb") as f_out:
-> 1313             torch.save(self, f_out)
   1314 
   1315         # In addition, we need to use PTL save_checkpoint() to properly save the trainer and model

site-packages/torch/serialization.py in save(obj, f, pickle_module, pickle_protocol, _use_new_zipfile_serialization)
    378         if _use_new_zipfile_serialization:
    379             with _open_zipfile_writer(opened_file) as opened_zipfile:
--> 380                 _save(obj, opened_zipfile, pickle_module, pickle_protocol)
    381                 return
    382         _legacy_save(obj, opened_file, pickle_module, pickle_protocol)

site-packages/torch/serialization.py in _save(obj, zip_file, pickle_module, pickle_protocol)
    587     pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)
    588     pickler.persistent_id = persistent_id
--> 589     pickler.dump(obj)
    590     data_value = data_buf.getvalue()
    591     zip_file.write_record('data.pkl', data_value, len(data_value))

Issue Analytics

  • State: closed
  • Created: a year ago
  • Comments: 5 (3 by maintainers)

Top GitHub Comments

1 reaction
gdevos010 commented, May 17, 2022

@DeastinY That was the wrong method to suggest; it is what darts does under the hood (I’m still new around here). I read through the code, and the recommended way to save is to use save_checkpoints=True. I created a simple example:

from darts.datasets import AirPassengersDataset
from darts.models.forecasting.tft_model import TFTModel
from darts.timeseries import TimeSeries

from pytorch_lightning.callbacks import (
    RichProgressBar,
)


ts: TimeSeries = AirPassengersDataset().load()

progress_bar = RichProgressBar()
pl_trainer_kwargs = {"accelerator": "gpu", "gpus": [0], "callbacks": [progress_bar]}

tft_model = TFTModel(
        model_name="TFT",
        batch_size=32,
        input_chunk_length=24,
        output_chunk_length=12,
        n_epochs=5,
        add_relative_index=True,
        pl_trainer_kwargs=pl_trainer_kwargs,
        save_checkpoints=True,  # use checkpointing to save the model
        force_reset=True,
    )
tft_model.fit(ts)

# load the model from a checkpoint (load_from_checkpoint returns the loaded model)
model = TFTModel.load_from_checkpoint(tft_model.model_name, best=False)
0 reactions
dennisbader commented, May 19, 2022

model.save_model() ultimately pickles the model, which can fail based on the objects added to the model. The built-in checkpointing only saves the most necessary things: hyperparameters, model weights, optimizer states, …

I just want to add that the recommended way to load the model from a checkpoint is (also documented in the darts documentation):

# use this (load_from_checkpoint() returns the loaded object and does not load inplace)
model_loaded = TFTModel.load_from_checkpoint(...)

# not this (inplace does not work)
model_original.load_from_checkpoint(...)

The same also applies to load_model().

