Error: TimeSeriesDataSet with target list
See original GitHub issue

- PyTorch-Forecasting version: 0.7.0
- PyTorch version: 1.7.0
- Python version: 3.8.5
- Operating System: Ubuntu 18.04
Hello,
I’m facing an error using TimeSeriesDataSet when trying to set up the “target” parameter with a list:
```python
max_prediction_length = 3  # 3 months
max_encoder_length = 12
training_cutoff = data["time_idx"].max() - max_prediction_length

training = TimeSeriesDataSet(
    data[lambda x: x.time_idx <= training_cutoff],
    time_idx="time_idx",
    # target="purchase_item_A",  # No error, works great!
    target=["purchase_item_A", "purchase_item_B"],  # Error
    group_ids=["client_id"],
    min_encoder_length=3,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["client_id"],
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_reals=["purchase_item_A", "purchase_item_B"],
    target_normalizer=GroupNormalizer(
        groups=["client_id"], transformation="softplus"
    ),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    allow_missings=True,
)
```
It works fine with a single column name as the target, but with a 2-element list it throws:
```
KeyError: '__target__'

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-63-932023d62ce0> in <module>
      4 training_cutoff = data["time_idx"].max() - max_prediction_length
      5
----> 6 training = TimeSeriesDataSet(
      7     data[lambda x: x.time_idx <= training_cutoff],
      8     time_idx="time_idx",

~/anaconda3/envs/ai4/lib/python3.8/site-packages/pytorch_forecasting/data/timeseries.py in __init__(self, data, time_idx, target, group_ids, weight, max_encoder_length, min_encoder_length, min_prediction_idx, min_prediction_length, max_prediction_length, static_categoricals, static_reals, time_varying_known_categoricals, time_varying_known_reals, time_varying_unknown_categoricals, time_varying_unknown_reals, variable_groups, dropout_categoricals, constant_fill_strategy, allow_missings, add_relative_time_idx, add_target_scales, add_encoder_length, target_normalizer, categorical_encoders, scalers, randomize_length, predict_mode)
    322
    323         # preprocess data
--> 324         data = self._preprocess_data(data)
    325
    326         # create index

~/anaconda3/envs/ai4/lib/python3.8/site-packages/pytorch_forecasting/data/timeseries.py in _preprocess_data(self, data)
    481         data["__time_idx__"] = data[self.time_idx]  # save unscaled
    482         assert "__target__" not in data.columns, "__target__ is a protected column and must not be present in data"
--> 483         data["__target__"] = data[self.target]
    484         if self.weight is not None:
    485             data["__weight__"] = data[self.weight]

~/anaconda3/envs/ai4/lib/python3.8/site-packages/pandas/core/frame.py in __setitem__(self, key, value)
   3042         else:
   3043             # set column
-> 3044             self._set_item(key, value)
   3045
   3046     def _setitem_slice(self, key: slice, value):

~/anaconda3/envs/ai4/lib/python3.8/site-packages/pandas/core/frame.py in _set_item(self, key, value)
   3119         self._ensure_valid_index(value)
   3120         value = self._sanitize_column(key, value)
-> 3121         NDFrame._set_item(self, key, value)
   3122
   3123         # check if we are modifying a copy

~/anaconda3/envs/ai4/lib/python3.8/site-packages/pandas/core/generic.py in _set_item(self, key, value)
   3575         except KeyError:
   3576             # This item wasn't present, just insert at end
-> 3577             self._mgr.insert(len(self._info_axis), key, value)
   3578             return
   3579

~/anaconda3/envs/ai4/lib/python3.8/site-packages/pandas/core/internals/managers.py in insert(self, loc, item, value, allow_duplicates)
   1187             value = _safe_reshape(value, (1,) + value.shape)
   1188
-> 1189         block = make_block(values=value, ndim=self.ndim, placement=slice(loc, loc + 1))
   1190
   1191         for blkno, count in _fast_count_smallints(self.blknos[loc:]):

~/anaconda3/envs/ai4/lib/python3.8/site-packages/pandas/core/internals/blocks.py in make_block(values, placement, klass, ndim, dtype)
   2720             values = DatetimeArray._simple_new(values, dtype=dtype)
   2721
-> 2722     return klass(values, ndim=ndim, placement=placement)
   2723
   2724

~/anaconda3/envs/ai4/lib/python3.8/site-packages/pandas/core/internals/blocks.py in __init__(self, values, placement, ndim)
    128
    129         if self._validate_ndim and self.ndim and len(self.mgr_locs) != len(self.values):
--> 130             raise ValueError(
    131                 f"Wrong number of items passed {len(self.values)}, "
    132                 f"placement implies {len(self.mgr_locs)}"

ValueError: Wrong number of items passed 2, placement implies 1
```
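The traceback bottoms out in plain pandas: `_preprocess_data` assigns `data[self.target]` to the single column `__target__`, and when `self.target` is a list that selection is a two-column DataFrame, which pandas refuses to store in one column. A minimal reproduction, independent of pytorch-forecasting (column names are illustrative):

```python
import pandas as pd

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})

# Assigning a single column works fine:
df["__target__"] = df["a"]

# Assigning a two-column selection to a single column raises ValueError,
# which is exactly what _preprocess_data hits when `target` is a list:
try:
    df["__target__"] = df[["a", "b"]]
except ValueError as e:
    print("ValueError:", e)
```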
I read that it supports a list, so I’m wondering how I could fix or work around this problem, please.
Thanks in advance!
Issue Analytics
- State:
- Created 3 years ago
- Comments: 13 (6 by maintainers)
that did the trick, thanks
You are using the GroupNormalizer, which can handle only one target. I suggest using the MultiNormalizer instead, or leaving the argument empty so the normalizer is determined automatically.
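The suggested fix can be sketched as follows: wrap one normalizer per target in a MultiNormalizer, in the same order as the `target` list. This is a hedged sketch, not a tested configuration; it assumes `MultiNormalizer` is available in your installed pytorch-forecasting version (it was added in later releases than 0.7.0, so an upgrade may be needed), and reuses the group/transformation choices from the original snippet:

```python
from pytorch_forecasting.data import GroupNormalizer, MultiNormalizer

# One normalizer per entry in `target`, in the same order:
target_normalizer = MultiNormalizer(
    [
        GroupNormalizer(groups=["client_id"], transformation="softplus"),  # purchase_item_A
        GroupNormalizer(groups=["client_id"], transformation="softplus"),  # purchase_item_B
    ]
)

# Then pass target_normalizer=target_normalizer to TimeSeriesDataSet,
# or simply omit the argument and let it be inferred automatically.
```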