map_location for TemporalFusionTransformer.load_from_checkpoint not working?
Hey there,
Great package so far. Much appreciated.
I am running into an issue when trying to load a GPU-trained model on a CPU-only machine.
- PyTorch-Forecasting version: 0.8.2
- PyTorch version: 1.7.1
- Python version: 3.8
- Operating System: Ubuntu / MacOS
Expected behavior
- I trained a TFT model on a machine with a GPU.
- I moved the auto-generated checkpoint to another machine without a GPU.
- I tried to load it with:
tft = TemporalFusionTransformer.load_from_checkpoint(model_path, map_location=torch.device('cpu'))
Actual behavior
- I still get this error:
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
Running the same code on the machine the model was originally trained on works fine.
Code to reproduce the problem
Training
training = TimeSeriesDataSet(
    out[lambda x: x.start_datetime <= end_training_datetime_range],
    time_idx='time_idx',
    target=VARIABLE,
    group_ids=["id"],
    categorical_encoders={"series": NaNLabelEncoder(add_nan=True).fit(out.id)},
    static_categoricals=["id"],
    time_varying_known_reals=[VARIABLES],
    time_varying_known_categoricals=[VARIABLES],
    time_varying_unknown_reals=[VARIABLE],
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    target_normalizer=GroupNormalizer(
        groups=["id"],
        transformation="softplus"
    ),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)
validation = TimeSeriesDataSet.from_dataset(training, out, min_prediction_idx=training.index.time.max() + 1, stop_randomization=True)
batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=8)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=8)
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=5, verbose=True, mode="min")
lr_logger = LearningRateMonitor()
logger = TensorBoardLogger("lightning_logs")  # logging results to a tensorboard
trainer = pl.Trainer(
    max_epochs=30,
    gpus=1,
    gradient_clip_val=0.1,
    # limit_train_batches=30,
    callbacks=[lr_logger, early_stop_callback],
    logger=logger
    # auto_lr_find=True
)
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=24,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=16,
    output_size=7,
    loss=QuantileLoss(),
    log_interval=2,
    reduce_on_plateau_patience=4
)
trainer.fit(
    tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader
)
Loading
tft = TemporalFusionTransformer.load_from_checkpoint(model_path, map_location=torch.device('cpu'))
Traceback
Traceback (most recent call last):
  File "/Applications/PyCharm CE.app/Contents/helpers/pydev/pydevd.py", line 1741, in <module>
    main()
  File "/Applications/PyCharm CE.app/Contents/helpers/pydev/pydevd.py", line 1735, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "/Applications/PyCharm CE.app/Contents/helpers/pydev/pydevd.py", line 1135, in run
    pydev_imports.execfile(file, globals, locals) # execute the script
  File "/Applications/PyCharm CE.app/Contents/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "airflow_dags/predictions.py", line 167, in <module>
    main()
  File "airflow_dags/predictions.py", line 163, in main
    create_predictions(**context)
  File "airflow_dags/predictions.py", line 37, in create_predictions
    tft = TemporalFusionTransformer.load_from_checkpoint(model_path, map_location=torch.device('cpu'))
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/pytorch_lightning/core/saving.py", line 158, in load_from_checkpoint
    model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/pytorch_lightning/core/saving.py", line 201, in _load_model_state
    model.on_load_checkpoint(checkpoint)
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/pytorch_forecasting/models/base_model.py", line 688, in on_load_checkpoint
    self.loss = cloudpickle.loads(checkpoint["loss"])
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/torch/storage.py", line 141, in _load_from_bytes
    return torch.load(io.BytesIO(b))
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/torch/serialization.py", line 595, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/torch/serialization.py", line 774, in _legacy_load
    result = unpickler.load()
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/torch/serialization.py", line 730, in persistent_load
    deserialized_objects[root_key] = restore_location(obj, location)
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/torch/serialization.py", line 175, in default_restore_location
    result = fn(storage, location)
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/torch/serialization.py", line 151, in _cuda_deserialize
    device = validate_cuda_device(location)
  File "/Users/dominique.vandamme/code/data-infrastructure/venv/lib/python3.7/site-packages/torch/serialization.py", line 135, in validate_cuda_device
    raise RuntimeError('Attempting to deserialize object on a CUDA '
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
Is it possible that some parts of the TFT architecture are currently not captured by the storage mapping?
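Judging from the traceback, the failing call is cloudpickle.loads(checkpoint["loss"]), which ends up in torch.storage._load_from_bytes and from there in torch.load without any map_location. A possible workaround, sketched below, is an unpickler that reroutes that one function so the nested tensors are mapped to the CPU. This is only a sketch under assumptions: CPUUnpickler and loads_on_cpu are hypothetical helper names and not part of pytorch-forecasting or PyTorch, and newer PyTorch releases changed torch.load defaults, so the rerouted call may need adjusting there.

```python
import io
import pickle


class CPUUnpickler(pickle.Unpickler):
    """Hypothetical helper: unpickle objects whose tensors were saved on a GPU."""

    def find_class(self, module, name):
        # Pickled tensor storages are rebuilt through torch.storage._load_from_bytes,
        # which calls torch.load without map_location (see the traceback above).
        if module == "torch.storage" and name == "_load_from_bytes":
            import torch  # imported lazily: only needed when tensors are present

            return lambda b: torch.load(io.BytesIO(b), map_location="cpu")
        # Everything else is resolved the normal way.
        return super().find_class(module, name)


def loads_on_cpu(payload: bytes):
    """Hypothetical counterpart to pickle.loads that maps tensors to the CPU."""
    return CPUUnpickler(io.BytesIO(payload)).load()
```

With this, loads_on_cpu(checkpoint["loss"]) on the dictionary returned by torch.load(model_path, map_location="cpu") would be the counterpart of the line that fails in on_load_checkpoint. Wiring it into the model class is not shown here.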
Issue Analytics
- State:
- Created 3 years ago
- Comments:7 (4 by maintainers)
Just quickly did it myself to get 0.8.3 over the line today.
Gave it a try, and it seems to work both CPU -> CPU and GPU -> CPU without the cloudpickling. Mind if I send a small PR?