From 91cd577763ce8b91b8c49b62504e8145b1f6c4cb Mon Sep 17 00:00:00 2001
From: Wesley Gifford <79663411+wgifford@users.noreply.github.com>
Date: Tue, 5 Nov 2024 14:55:39 -0500
Subject: [PATCH] check bounds on inner datasets

---
Reviewer notes (free text after the "---" separator; not part of the commit):

* Hunk 1 adds a `_check_index(index)` helper next to `__len__`. As written it:
    - raises IndexError when `index >= len(self)`;
    - for a negative index, raises ValueError when `-index > len(self)`,
      otherwise rewrites it to `len(self) + index`;
    - returns the (possibly rewritten) index.
* Hunks 2-5 call `self._check_index(index)` at the top of four `__getitem__`
  methods, before `time_id = index * self.stride` is computed from it.
* NOTE(review): an out-of-range negative index raises ValueError while an
  out-of-range positive index raises IndexError. Built-in sequences raise
  IndexError in both cases -- confirm the asymmetry is intentional.
* NOTE(review): this copy of the patch had lost its line breaks and
  indentation. The layout below (blank context lines, indentation depth of
  the `__getitem__` hunks) was reconstructed from the hunk line counts and is
  an assumption -- regenerate the patch from the commit before applying it.

 tsfm_public/toolkit/dataset.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/tsfm_public/toolkit/dataset.py b/tsfm_public/toolkit/dataset.py
index fb445391..56698b7d 100644
--- a/tsfm_public/toolkit/dataset.py
+++ b/tsfm_public/toolkit/dataset.py
@@ -135,6 +135,16 @@ def pad_zero(self, data_df):
     def __len__(self):
         return max((len(self.X) - self.context_length - self.prediction_length) // self.stride + 1, 0)
 
+    def _check_index(self, index: int) -> int:
+        if index >= len(self):
+            raise IndexError("Index exceeds dataset length")
+
+        if index < 0:
+            if -index > len(self):
+                raise ValueError("Absolute value of index should not exceed dataset length")
+            index = len(self) + index
+        return index
+
     def __getitem__(self, index: int):
         """
         Args:
@@ -358,6 +368,8 @@ def __init__(
             )
 
         def __getitem__(self, index):
+            index = self._check_index(index)
+
             time_id = index * self.stride
             seq_x = self.X[time_id : time_id + self.context_length].values
             ret = {
@@ -565,6 +577,7 @@ def apply_masking_specification(self, past_values_tensor: np.ndarray) -> np.ndar
 
         def __getitem__(self, index):
             # seq_x: batch_size x seq_len x num_x_cols
+            index = self._check_index(index)
 
             time_id = index * self.stride
 
@@ -715,6 +728,7 @@ def __init__(
 
         def __getitem__(self, index):
             # seq_x: batch_size x seq_len x num_x_cols
+            index = self._check_index(index)
 
             time_id = index * self.stride
             seq_x = self.X[time_id : time_id + self.context_length].values
@@ -840,6 +854,7 @@ def __init__(
 
         def __getitem__(self, index):
             # seq_x: batch_size x seq_len x num_x_cols
+            index = self._check_index(index)
 
             time_id = index * self.stride
             seq_x = self.X[time_id : time_id + self.context_length].values