TemporalFusionTransformer predicting with mode "raw" results in RuntimeError: Sizes of tensors must match except in dimension 0
See original GitHub issue.
- PyTorch-Forecasting version: 0.8.4
- PyTorch version: 1.8.1
- Python version: 3.8
- Operating System: Ubuntu 20.04.2 LTS
Expected behavior
In order to generate the interpretation plots of a Temporal Fusion Transformer model, I first try to generate the raw predictions using:
# Generate TimeSeriesDataSets and DataLoaders
train_ts_dataset = TimeSeriesDataSet(
    train_dataset,
    time_idx=feature_config["time_index"],
    target=feature_config["target"],
    group_ids=feature_config["groups"],
    min_encoder_length=dataset_config["min_encoder_length"],  # 7
    max_encoder_length=dataset_config["max_encoder_length"],  # 28
    min_prediction_idx=1,
    min_prediction_length=dataset_config["min_prediction_length"],  # 1
    max_prediction_length=dataset_config["max_prediction_length"],  # 14
    static_categoricals=feature_config["categorical"]["static"],
    static_reals=feature_config["real"]["static"],
    time_varying_known_categoricals=feature_config["categorical"]["dynamic"]["known"],
    time_varying_unknown_categoricals=feature_config["categorical"]["dynamic"]["unknown"],
    time_varying_known_reals=feature_config["real"]["dynamic"]["known"],
    time_varying_unknown_reals=feature_config["real"]["dynamic"]["unknown"],
    target_normalizer=GroupNormalizer(groups=[], transformation="softplus"),
    allow_missings=True,
    categorical_encoders={
        col: NaNLabelEncoder(add_nan=True, warn=False)
        for col in set(
            feature_config["categorical"]["static"]
            + feature_config["categorical"]["dynamic"]["known"]
            + feature_config["categorical"]["dynamic"]["unknown"]
            + feature_config["groups"]
        )
    },
)

test_ts_dataset = TimeSeriesDataSet.from_dataset(
    train_ts_dataset, test_dataset, stop_randomization=True, predict_mode=True
)
test_dataloader = test_ts_dataset.to_dataloader(
    train=False, batch_size=batch_size * 10, num_workers=num_workers
)

tft = TemporalFusionTransformer.from_dataset(
    train_ts_dataset,
    loss=QuantileLoss(quantiles=model_config["quantiles"]),
    output_size=len(model_config["quantiles"]),
    **model_config["hyperparameters"],
)

...  # Train TFT on training set etc.

# Generate raw predictions.
raw_predictions = tft.predict(test_dataloader, mode="raw")
Actual behavior
The predict method, however, raises an error when using mode="raw":
...
    raw_predictions = tft.predict(dataloader, mode="raw")
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_forecasting/models/base_model.py", line 982, in predict
    output = _concatenate_output(output)
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_forecasting/models/base_model.py", line 84, in _concatenate_output
    output_cat[name] = _torch_cat_na([out[name] for out in output])
  File "/home/ubuntu/.local/share/virtualenvs/tmp-XVr6zr33/lib/python3.8/site-packages/pytorch_forecasting/models/base_model.py", line 62, in _torch_cat_na
    return torch.cat(x, dim=0)
RuntimeError: Sizes of tensors must match except in dimension 0. Got 40 and 35 in dimension 3 (The offending index is 1)
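The failure is easier to see in isolation. Below is a minimal, pure-Python sketch (no torch required) of the shape rule that torch.cat enforces; cat_dim0 is a hypothetical helper written for illustration, not part of pytorch-forecasting:

```python
def cat_dim0(shapes):
    """Shape of concatenating tensors of the given shapes along dim 0.

    Mimics torch.cat's rule: all dimensions except dim 0 must match,
    otherwise a RuntimeError like the one in the traceback is raised.
    """
    first = shapes[0]
    for s in shapes[1:]:
        for dim in range(1, len(first)):
            if s[dim] != first[dim]:
                raise RuntimeError(
                    "Sizes of tensors must match except in dimension 0. "
                    f"Got {first[dim]} and {s[dim]} in dimension {dim}"
                )
    return (sum(s[0] for s in shapes),) + first[1:]

# Attention weights per batch: (batch, decoder_steps, heads, attended_timesteps).
# When the last dimension differs between batches (40 vs 35, as in the
# traceback), concatenation along dim 0 fails.
ok = cat_dim0([(64, 14, 4, 40), (16, 14, 4, 40)])  # (80, 14, 4, 40)
try:
    cat_dim0([(64, 14, 4, 40), (32, 14, 4, 35)])
except RuntimeError as e:
    print(e)
```

The batch sizes and shapes here are illustrative; only the 40-vs-35 mismatch in dimension 3 is taken from the traceback.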
Note that I have used the same dataloader earlier in a predict call without mode="raw", and that works perfectly fine.
The same error was raised elsewhere, as documented in https://github.com/jdb78/pytorch-forecasting/issues/85, and fixed by https://github.com/jdb78/pytorch-forecasting/pull/108/files. Could it be that this padded_stack function should also be used somewhere else, or is something else going on? Perhaps somewhere in _concatenate_output(), which is called in this line: https://github.com/jdb78/pytorch-forecasting/blob/master/pytorch_forecasting/models/base_model.py#L987 ?
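For intuition, here is a value-level sketch of what a padded_stack-style fix does: pad the mismatched trailing dimension with NaN so the pieces line up, then concatenate. This is pure Python with 2-D lists standing in for tensors, and cat_with_padding is a made-up name for illustration, not the library's API:

```python
import math

def cat_with_padding(batches, fill=math.nan):
    """Concatenate 2-D "tensors" (lists of rows) along dim 0, padding the
    last dimension with NaN so rows of different lengths line up - the
    idea behind the padded_stack fix referenced above."""
    width = max(len(row) for batch in batches for row in batch)
    out = []
    for batch in batches:
        for row in batch:
            out.append(row + [fill] * (width - len(row)))
    return out

a = [[1.0, 2.0, 3.0]]  # batch attending over 3 timesteps
b = [[4.0, 5.0]]       # batch attending over only 2
merged = cat_with_padding([a, b])
# merged == [[1.0, 2.0, 3.0], [4.0, 5.0, nan]]; NaN marks padded positions
```

Padding with NaN rather than zero keeps padded positions distinguishable from real attention weights downstream.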
Please let me know whether it is indeed a bug, or if I am doing something wrong 🙏 Thank you!
Issue Analytics
- Created: 2 years ago
- Reactions: 5
- Comments: 21 (7 by maintainers)
It seems to be related to the concatenation of the attention tensor (n_batches x n_decoder_steps (that attend) x n_attention_heads x n_timesteps (to which is attended)). There appears to be an issue with the concatenation logic. Will work on a fix.
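Following the maintainer's description, here is a shape-level sketch of how padding dimension 3 before the dim-0 concatenation would resolve exactly that mismatch. padded_cat_shape is a hypothetical illustration, not the library's actual code:

```python
def padded_cat_shape(shapes, pad_dim=3):
    """Shape arithmetic for "pad, then concatenate along dim 0".

    Every tensor's pad_dim is grown to the batch maximum (in a real fix the
    new positions would be filled with NaN); afterwards all non-leading
    dimensions agree, so a dim-0 concatenation becomes legal.
    """
    target = max(s[pad_dim] for s in shapes)
    padded = [s[:pad_dim] + (target,) + s[pad_dim + 1:] for s in shapes]
    assert all(p[1:] == padded[0][1:] for p in padded)
    return (sum(p[0] for p in padded),) + padded[0][1:]

# The two offending attention-weight batches from the traceback, laid out as
# (n_batches, n_decoder_steps, n_attention_heads, n_attended_timesteps):
print(padded_cat_shape([(64, 14, 4, 40), (32, 14, 4, 35)]))  # (96, 14, 4, 40)
```

Again, the batch and head counts are assumptions for illustration; only the 40 and 35 in the last dimension come from the reported error.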
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.001,
    hidden_size=16,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=7,
    loss=QuantileLoss(),
    reduce_on_plateau_patience=4,
    optimizer="adam",
)