
map_location for TemporalFusionTransformer.load_from_checkpoint not working?


Hey there,

great package so far. Much appreciated.

I am running into an issue when trying to load a GPU-trained model on a CPU-only machine.

  • PyTorch-Forecasting version: 0.8.2
  • PyTorch version: 1.7.1
  • Python version: 3.8
  • Operating System: Ubuntu / MacOS

Expected behavior

  • I trained a TFT model on a machine with a GPU.
  • I moved the auto-generated checkpoint to another machine without a GPU.
  • I tried to load it with: tft = TemporalFusionTransformer.load_from_checkpoint(model_path, map_location=torch.device('cpu'))

Actual behavior

  • I still get this error: RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

Using the same code on the machine the model was trained on works fine.

Code to reproduce the problem

Training

import pytorch_lightning as pl
import torch
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer, NaNLabelEncoder
from pytorch_forecasting.metrics import QuantileLoss
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

training = TimeSeriesDataSet(
    out[lambda x: x.start_datetime <= end_training_datetime_range],
    time_idx='time_idx',
    target=VARIABLE,
    group_ids=["id"],
    categorical_encoders={"series": NaNLabelEncoder(add_nan=True).fit(out.id)},
    static_categoricals=["id"],
    time_varying_known_reals=[VARIABLES],
    time_varying_known_categoricals=[VARIABLES],
    time_varying_unknown_reals=[VARIABLE],
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    target_normalizer=GroupNormalizer(
        groups=["id"],
        transformation="softplus"
    ),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

validation = TimeSeriesDataSet.from_dataset(
    training, out, min_prediction_idx=training.index.time.max() + 1, stop_randomization=True
)
batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=8)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=8)


early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=5, verbose=True, mode="min")
lr_logger = LearningRateMonitor()
logger = TensorBoardLogger("lightning_logs")  # logging results to a tensorboard

trainer = pl.Trainer(
    max_epochs=30,
    gpus=1,
    gradient_clip_val=0.1,
    # limit_train_batches=30,
    callbacks=[lr_logger, early_stop_callback],
    logger=logger,
    # auto_lr_find=True
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=24,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=16,
    output_size=7,
    loss=QuantileLoss(),
    log_interval=2,
    reduce_on_plateau_patience=4
)

trainer.fit(
    tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader
)

Loading

tft = TemporalFusionTransformer.load_from_checkpoint(model_path, map_location=torch.device('cpu'))

Traceback

Traceback (most recent call last):
  File "/Applications/PyCharm CE.app/Contents/helpers/pydev/pydevd.py", line 1741, in <module>
    main()
  File "/Applications/PyCharm CE.app/Contents/helpers/pydev/pydevd.py", line 1735, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "/Applications/PyCharm CE.app/Contents/helpers/pydev/pydevd.py", line 1135, in run
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/Applications/PyCharm CE.app/Contents/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "airflow_dags/predictions.py", line 167, in <module>
    main()
  File "airflow_dags/predictions.py", line 163, in main
    create_predictions(**context)
  File "airflow_dags/predictions.py", line 37, in create_predictions
    tft = TemporalFusionTransformer.load_from_checkpoint(model_path, map_location=torch.device('cpu'))
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/pytorch_lightning/core/saving.py", line 158, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/pytorch_lightning/core/saving.py", line 201, in _load_model_state
    model.on_load_checkpoint(checkpoint)
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/pytorch_forecasting/models/base_model.py", line 688, in on_load_checkpoint
    self.loss = cloudpickle.loads(checkpoint["loss"])
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/torch/storage.py", line 141, in _load_from_bytes
    return torch.load(io.BytesIO(b))
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/torch/serialization.py", line 595, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/torch/serialization.py", line 774, in _legacy_load
    result = unpickler.load()
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/torch/serialization.py", line 730, in persistent_load
    deserialized_objects[root_key] = restore_location(obj, location)
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/torch/serialization.py", line 175, in default_restore_location
    result = fn(storage, location)
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/torch/serialization.py", line 151, in _cuda_deserialize
    device = validate_cuda_device(location)
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/torch/serialization.py", line 135, in validate_cuda_device
    raise RuntimeError('Attempting to deserialize object on a CUDA '
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

Is it possible that some features of the TFT architecture are currently not captured by the storage mapping?
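
From the traceback, the map_location argument never reaches the torch.load call that restores the cloudpickled loss (torch.storage._load_from_bytes calls torch.load without it). As a stop-gap on the CPU-only machine, temporarily defaulting torch.load to the CPU seems to work around it. A rough sketch, not thoroughly tested, with model_path as above:

import torch
from pytorch_forecasting import TemporalFusionTransformer

# The cloudpickled loss is restored via a bare torch.load() deep inside
# torch.storage._load_from_bytes, so the map_location passed to
# load_from_checkpoint never reaches it. Make CPU the default temporarily.
_original_load = torch.load

def _load_to_cpu(*args, **kwargs):
    kwargs.setdefault("map_location", torch.device("cpu"))
    return _original_load(*args, **kwargs)

torch.load = _load_to_cpu
try:
    tft = TemporalFusionTransformer.load_from_checkpoint(
        model_path, map_location=torch.device("cpu")
    )
finally:
    torch.load = _original_load  # restore the original behaviour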


Top GitHub Comments

jdb78 commented, Jan 31, 2021

Just quickly did it myself to get 0.8.3 over the line today.

domplexity commented, Jan 31, 2021

Gave it a try and it seems to be working for both CPU -> CPU and GPU -> CPU without the cloudpickling. Mind if I send a small PR?
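
The check was roughly along these lines (a sketch, not the exact script; model_path points to the checkpoint copied over from the GPU machine):

import torch
from pytorch_forecasting import TemporalFusionTransformer

# Load the checkpoint produced on the GPU machine on a CPU-only box.
tft = TemporalFusionTransformer.load_from_checkpoint(
    model_path, map_location=torch.device("cpu")
)

# All parameters should end up on the CPU.
assert all(p.device.type == "cpu" for p in tft.parameters())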


