From 39a28bab4f8effbdaba5b56cc32c32c22800611a Mon Sep 17 00:00:00 2001 From: Denys Senkin Date: Tue, 21 Dec 2021 19:33:08 +0100 Subject: [PATCH 1/4] Replace DataFrame of indices with np.recarray --- pytorch_forecasting/data/timeseries.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index 497cd1b1..4d154d1f 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -1247,7 +1247,7 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFra len(df_index) > 0 ), "filters should not remove entries all entries - check encoder/decoder lengths and lags" - return df_index + return df_index.to_records(index=False) def filter(self, filter_func: Callable, copy: bool = True) -> "TimeSeriesDataSet": """ @@ -1404,17 +1404,22 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: Returns: Tuple[Dict[str, torch.Tensor], torch.Tensor]: x and y for model """ - index = self.index.iloc[idx] + index = self.index[idx] + # get index data + index_start = index.index_start + index_end = index.index_end + index_sequence_length = index.sequence_length + # get index data - data_cont = self.data["reals"][index.index_start : index.index_end + 1].clone() - data_cat = self.data["categoricals"][index.index_start : index.index_end + 1].clone() - time = self.data["time"][index.index_start : index.index_end + 1].clone() - target = [d[index.index_start : index.index_end + 1].clone() for d in self.data["target"]] - groups = self.data["groups"][index.index_start].clone() + data_cont = self.data["reals"][index_start : index_end + 1].clone() + data_cat = self.data["categoricals"][index_start : index_end + 1].clone() + time = self.data["time"][index_start : index_end + 1].clone() + target = [d[index_start : index_end + 1].clone() for d in self.data["target"]] + groups = 
self.data["groups"][index_start].clone() if self.data["weight"] is None: weight = None else: - weight = self.data["weight"][index.index_start : index.index_end + 1].clone() + weight = self.data["weight"][index_start : index_end + 1].clone() # get target scale in the form of a list target_scale = self.target_normalizer.get_parameters(groups, self.group_ids) if not isinstance(self.target_normalizer, MultiNormalizer): @@ -1422,7 +1427,7 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: # fill in missing values (if not all time indices are specified sequence_length = len(time) - if sequence_length < index.sequence_length: + if sequence_length < index_sequence_length: assert self.allow_missing_timesteps, "allow_missing_timesteps should be True if sequences have gaps" repetitions = torch.cat([time[1:] - time[:-1], torch.ones(1, dtype=time.dtype)]) indices = torch.repeat_interleave(torch.arange(len(time)), repetitions) From 46646c06359df64340b44bda3747557225548309 Mon Sep 17 00:00:00 2001 From: Denys Senkin Date: Wed, 22 Dec 2021 18:08:28 +0100 Subject: [PATCH 2/4] Enable index field (to avoid changes in other files) --- pytorch_forecasting/data/timeseries.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index 4d154d1f..36cf990d 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -1247,8 +1247,7 @@ def _construct_index(self, data: pd.DataFrame, predict_mode: bool) -> pd.DataFra len(df_index) > 0 ), "filters should not remove entries all entries - check encoder/decoder lengths and lags" - return df_index.to_records(index=False) - + return df_index.to_records(index=True) def filter(self, filter_func: Callable, copy: bool = True) -> "TimeSeriesDataSet": """ Filter subsequences in dataset. 
From 08cb6365b810f3e92a48a9a5c07b74549e4f6e60 Mon Sep 17 00:00:00 2001 From: Denys Senkin Date: Sun, 26 Dec 2021 16:11:59 +0100 Subject: [PATCH 3/4] Drop redundant .to_numpy() calls now that the index is a numpy recarray --- pytorch_forecasting/data/timeseries.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index 36cf990d..a7d05507 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -1291,8 +1291,8 @@ def decoded_index(self) -> pd.DataFrame: pd.DataFrame: index that can be understood in terms of original data """ # get dataframe to filter - index_start = self.index["index_start"].to_numpy() - index_last = self.index["index_end"].to_numpy() + index_start = self.index["index_start"] + index_last = self.index["index_end"] index = ( # get group ids in order of index pd.DataFrame(self.data["groups"][index_start].numpy(), columns=self.group_ids) From eb706f90706929e18421c59ef727adc9d052e8df Mon Sep 17 00:00:00 2001 From: Denys Senkin Date: Tue, 28 Dec 2021 19:12:58 +0100 Subject: [PATCH 4/4] Remove stray whitespace --- pytorch_forecasting/data/timeseries.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_forecasting/data/timeseries.py b/pytorch_forecasting/data/timeseries.py index a7d05507..dcfdf688 100644 --- a/pytorch_forecasting/data/timeseries.py +++ b/pytorch_forecasting/data/timeseries.py @@ -1404,10 +1404,10 @@ def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: Tuple[Dict[str, torch.Tensor], torch.Tensor]: x and y for model """ index = self.index[idx] - # get index data + # get index data index_start = index.index_start index_end = index.index_end - index_sequence_length = index.sequence_length + index_sequence_length = index.sequence_length # get index data data_cont = self.data["reals"][index_start : index_end + 1].clone()