Skip to content

Commit

Permalink
Merge pull request #111 from ibm-granite/regression_updates
Browse files Browse the repository at this point in the history
Regression updates
  • Loading branch information
wgifford authored Aug 16, 2024
2 parents 8c79a5e + d7a4327 commit 0263bd5
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 59 deletions.
23 changes: 23 additions & 0 deletions tests/toolkit/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ClassificationDFDataset,
ForecastDFDataset,
PretrainDFDataset,
RegressionDFDataset,
ts_padding,
)
from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor
Expand Down Expand Up @@ -103,6 +104,28 @@ def test_pretrain_df_dataset(ts_data):
assert ds[0]["timestamp"] == ts_data.iloc[-1]["time_date"]


def test_regression_df_dataset(ts_data):
    """Verify RegressionDFDataset groups by id and aligns targets per group."""
    # Build a second series, id "B", whose target channel is shifted by +100
    # so the two groups are distinguishable in the produced windows.
    second_series = ts_data.copy()
    second_series["id"] = "B"
    second_series["val2"] = second_series["val2"] + 100
    combined = pd.concat([ts_data, second_series], axis=0)

    ds = RegressionDFDataset(
        combined,
        id_columns=["id"],
        timestamp_column="time_date",
        input_columns=["val"],
        target_columns=["val2"],
        context_length=5,
    )

    # Test proper target alignment: the first window belongs to group "A",
    # while the matching window in group "B" carries the +100-shifted target.
    np.testing.assert_allclose(ds[0]["target_values"].numpy(), np.asarray([104]))
    assert ds[0]["id"] == ("A",)
    np.testing.assert_allclose(ds[6]["target_values"].numpy(), np.asarray([204]))
    assert ds[6]["id"] == ("B",)


def test_forecasting_df_dataset(ts_data_with_categorical):
prediction_length = 2
static_categorical_columns = ["color", "material"]
Expand Down
182 changes: 123 additions & 59 deletions tsfm_public/toolkit/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,21 @@


class BaseDFDataset(torch.utils.data.Dataset):
"""
An abtract class representing a :class: `BaseDFDataset`.
All the datasets that represents data frames should subclass it.
All subclasses should overwrite :meth: `__get_item__`
"""Base dataset for time series models built upon a pandas dataframe
Args:
data_df (DataFrame, required): input data
datetime_col (str, optional): datetime column in the data_df. Defaults to None
x_cols (list, optional): list of columns of X. If x_cols is an empty list, all the columns in the data_df is taken, except the datatime_col. Defaults to an empty list.
y_cols (list, required): list of columns of y. Defaults to an empty list.
seq_len (int, required): the sequence length. Defaults to 1
pred_len (int, required): forecasting horizon. Defaults to 0.
zero_padding (bool, optional): pad zero if the data_df is shorter than seq_len+pred_len
data_df (pd.DataFrame): Underlying pandas dataframe.
id_columns (List[str], optional): List of columns which contain id information to separate distinct time series. Defaults to [].
timestamp_column (Optional[str], optional): Name of the timestamp column. Defaults to None.
group_id (Optional[Union[List[int], List[str]]], optional): The id value(s) identifying the single time series (group) contained in this dataset. Defaults to None.
x_cols (list, optional): Columns to treat as inputs. If an empty list ([]) all the columns in the data_df are taken, except the timestamp column. Defaults to [].
y_cols (list, optional): Columns to treat as outputs. Defaults to [].
drop_cols (list, optional): List of columns that are dropped to form the X matrix (input). Defaults to [].
context_length (int, optional): Length of historical data used when creating individual examples in the torch dataset. Defaults to 1.
prediction_length (int, optional): Length of prediction (future values). Defaults to 0.
zero_padding (bool, optional): If True, windows of context_length+prediction_length which are too short are padded with zeros. Defaults to True.
stride (int, optional): Stride at which windows are produced. Defaults to 1.
fill_value (Union[float, int], optional): Value used to fill any missing values. Defaults to 0.0.
"""

def __init__(
Expand Down Expand Up @@ -133,19 +134,18 @@ def __getitem__(self, index: int):


class BaseConcatDFDataset(torch.utils.data.ConcatDataset):
"""
An abtract class representing a :class: `BaseConcatDFDataset`.
"""A dataset consisting of a concatenation of other datasets, based on torch ConcatDataset.
Args:
data_df (DataFrame, required): input data
datetime_col (str, optional): datetime column in the data_df. Defaults to None
x_cols (list, optional): list of columns of X. If x_cols is an empty list, all the columns in the data_df is taken, except the datatime_col. Defaults to an empty list.
y_cols (list, required): list of columns of y. Defaults to an empty list.
group_ids (list, optional): list of group_ids to split the data_df to different groups. If group_ids is defined, it will triggle the groupby method in DataFrame. If empty, entire data frame is treated as one group.
seq_len (int, required): the sequence length. Defaults to 1
num_workers (int, optional): the number if workers used for creating a list of dataset from group_ids. Defaults to 1.
pred_len (int, required): forecasting horizon. Defaults to 0.
cls (class, required): dataset class
data_df (pd.DataFrame): Underlying pandas dataframe.
id_columns (List[str], optional): List of columns which contain id information to separate distinct time series. Defaults to [].
timestamp_column (Optional[str], optional): Name of the timestamp column. Defaults to None.
context_length (int, optional): Length of historical data used when creating individual examples in the torch dataset. Defaults to 1.
prediction_length (int, optional): Length of prediction (future values). Defaults to 0.
num_workers (int, optional): (Currently not used) Number of workers. Defaults to 1.
fill_value (Union[float, int], optional): Value used to fill any missing values. Defaults to 0.0.
cls (_type_, optional): The dataset class used to create the underlying datasets. Defaults to BaseDFDataset.
stride (int, optional): Stride at which windows are produced. Defaults to 1.
"""

def __init__(
Expand Down Expand Up @@ -258,17 +258,25 @@ def get_group_data(


class PretrainDFDataset(BaseConcatDFDataset):
"""
A :class: `PretrainDFDataset` is used for pretraining.
"""A dataset used for masked pre-training.
To be updated
Args:
data_df (DataFrame, required): input data
datetime_col (str, optional): datetime column in the data_df. Defaults to None
x_cols (list, optional): list of columns of X. If x_cols is an empty list, all the columns in the data_df is taken, except the datatime_col. Defaults to an empty list.
group_ids (list, optional): list of group_ids to split the data_df to different groups. If group_ids is defined, it will triggle the groupby method in DataFrame. If empty, entire data frame is treated as one group.
seq_len (int, required): the sequence length. Defaults to 1
num_workers (int, optional): the number if workers used for creating a list of dataset from group_ids. Defaults to 1.
data (pd.DataFrame): Underlying pandas dataframe.
id_columns (List[str], optional): List of columns which contain id information to separate distinct time series. Defaults to [].
timestamp_column (Optional[str], optional): Name of the timestamp column. Defaults to None.
target_columns (List[str], optional): List of column names which identify the target channels in the input, these are the
columns that will be predicted. Defaults to [].
context_length (int, optional): Length of historical data used when creating individual examples in the torch dataset. Defaults to 1.
num_workers (int, optional): (Currently not used) Number of workers. Defaults to 1.
stride (int, optional): Stride at which windows are produced. Defaults to 1.
fill_value (Union[float, int], optional): Value used to fill any missing values. Defaults to 0.0.
The resulting dataset returns records (dictionaries) containing:
past_values: tensor of past values of the target columns of length equal to context length
past_observed_mask: tensor indicating which values are observed in the past values tensor
timestamp: the timestamp of the end of the context window
id: a tuple of id values (taken from the id columns) containing the id information of the time series segment
"""

def __init__(
Expand Down Expand Up @@ -345,17 +353,49 @@ def __getitem__(self, index):


class ForecastDFDataset(BaseConcatDFDataset):
"""
A :class: `ForecastDFDataset` used for forecasting.
"""A dataset used for forecasting pretraining and inference
Args:
data_df (DataFrame, required): input data
datetime_col (str, optional): datetime column in the data_df. Defaults to None
x_cols (list, optional): list of columns of X. If x_cols is an empty list, all the columns in the data_df is taken, except the datatime_col. Defaults to an empty list.
group_ids (list, optional): list of group_ids to split the data_df to different groups. If group_ids is defined, it will triggle the groupby method in DataFrame. If empty, entire data frame is treated as one group.
seq_len (int, required): the sequence length. Defaults to 1
num_workers (int, optional): the number if workers used for creating a list of dataset from group_ids. Defaults to 1.
pred_len (int, required): forecasting horizon. Defaults to 0.
data (pd.DataFrame): Underlying pandas dataframe.
id_columns (List[str], optional): List of columns which contain id information to separate distinct time series. Defaults
to [].
timestamp_column (Optional[str], optional): Name of the timestamp column. Defaults to None.
target_columns (List[str], optional): List of column names which identify the target channels in the input, these are the
columns that will be predicted. Defaults to [].
observable_columns (List[str], optional): List of column names which identify the observable channels in the input.
Observable channels are channels which we have knowledge about in the past and future. For example, weather
conditions such as temperature or precipitation may be known or estimated in the future, but cannot be
changed. Defaults to [].
control_columns (List[str], optional): List of column names which identify the control channels in the input. Control
channels are similar to observable channels, except that future values may be controlled. For example, discount
percentage of a particular product is known and controllable in the future. Defaults to [].
conditional_columns (List[str], optional): List of column names which identify the conditional channels in the input.
Conditional channels are channels which we know in the past, but do not know in the future. Defaults to [].
static_categorical_columns (List[str], optional): List of column names which identify categorical-valued channels in the
input which are fixed over time. Defaults to [].
context_length (int, optional): Length of historical data used when creating individual examples in the torch dataset.
Defaults to 1.
prediction_length (int, optional): Length of the future forecast. Defaults to 1.
num_workers (int, optional): (Currently not used) Number of workers. Defaults to 1.
frequency_token (Optional[int], optional): An integer representing the frequency of the data. Please see the time series
preprocessor documentation for an example of frequency token mappings. Defaults to None.
autoregressive_modeling (bool, optional): (Experimental) If False, any target values in the context window are masked and
replaced by 0. If True, the context window contains all the historical target information. Defaults to True.
stride (int, optional): Stride at which windows are produced. Defaults to 1.
fill_value (Union[float, int], optional): Value used to fill any missing values. Defaults to 0.0.
The resulting dataset returns records (dictionaries) containing:
past_values: tensor of past values of the target columns of length equal to context length (context_length x number of features)
past_observed_mask: tensor indicating which values are observed in the past values tensor (context_length x number of features)
future_values: tensor of future values of the target columns of length equal to prediction length (prediction_length x number of features)
future_observed_mask: tensor indicating which values are observed in the future values tensor (prediction_length x number of features)
freq_token: tensor containing the frequency token (scalar)
static_categorical_features: tensor of static categorical features (1 x len(static_categorical_columns))
timestamp: the timestamp of the end of the context window
id: a tuple of id values (taken from the id columns) containing the id information of the time series segment
where number of features is the total number of columns specified in target_columns, observable_columns, control_columns,
conditional_columns
"""

def __init__(
Expand Down Expand Up @@ -510,17 +550,29 @@ def __len__(self):


class RegressionDFDataset(BaseConcatDFDataset):
"""
A :class: `RegressionDFDataset` used for regression.
"""A dataset used for regression model training and inference
Args:
data_df (DataFrame, required): input data
datetime_col (str, optional): datetime column in the data_df. Defaults to None
input_columns (list, optional): list of columns of X. If x_cols is an empty list, all the columns in the data_df is taken, except the datatime_col. Defaults to an empty list.
output_columns (list, required): list of columns of y. Defaults to an empty list.
id_columns (list, optional): List of columns that specify ids in the dataset. list of group_ids to split the data_df to different groups. If group_ids is defined, it will triggle the groupby method in DataFrame. If empty, entire data frame is treated as one group.
context_length (int, required): the sequence length. Defaults to 1
num_workers (int, optional): the number if workers used for creating a list of dataset from group_ids. Defaults to 1.
data (pd.DataFrame): Underlying pandas dataframe.
id_columns (List[str], optional): List of columns which contain id information to separate distinct time series. Defaults
to [].
timestamp_column (Optional[str], optional): Name of the timestamp column. Defaults to None.
input_columns (List[str], optional): List of columns to use as inputs to the regression
target_columns (List[str], optional): List of column names which identify the target channels in the input, these are the
columns that will be predicted. Defaults to [].
context_length (int, optional): Length of historical data used when creating individual examples in the torch dataset.
Defaults to 1.
num_workers (int, optional): (Currently not used) Number of workers. Defaults to 1.
stride (int, optional): Stride at which windows are produced. Defaults to 1.
fill_value (Union[float, int], optional): Value used to fill any missing values. Defaults to 0.0.
The resulting dataset returns records (dictionaries) containing:
past_values: tensor of past values of the target columns of length equal to context length (context_length x len(input_columns))
past_observed_mask: tensor indicating which values are observed in the past values tensor (context_length x len(input_columns))
target_values: tensor of future values of the target columns of length equal to prediction length (prediction_length x len(target_columns))
static_categorical_features: tensor of static categorical features (1 x len(static_categorical_columns))
timestamp: the timestamp of the end of the context window
id: a tuple of id values (taken from the id columns) containing the id information of the time series segment
"""

def __init__(
Expand Down Expand Up @@ -623,16 +675,28 @@ def __getitem__(self, index):


class ClassificationDFDataset(BaseConcatDFDataset):
"""
A dataset for use with time series classification.
"""A dataset used for classification model training and inference
Args:
data_df (DataFrame, required): input data
datetime_col (str, optional): datetime column in the data_df. Defaults to None
x_cols (list, optional): list of columns of X. If x_cols is an empty list, all the columns in the data_df is taken, except the datatime_col. Defaults to an empty list.
y_cols (list, required): list of columns of y. Defaults to an empty list.
group_ids (list, optional): list of group_ids to split the data_df to different groups. If group_ids is defined, it will triggle the groupby method in DataFrame. If empty, entire data frame is treated as one group.
seq_len (int, required): the sequence length. Defaults to 1
num_workers (int, optional): the number if workers used for creating a list of dataset from group_ids. Defaults to 1.
data (pd.DataFrame): Underlying pandas dataframe.
id_columns (List[str], optional): List of columns which contain id information to separate distinct time series. Defaults
to [].
timestamp_column (Optional[str], optional): Name of the timestamp column. Defaults to None.
input_columns (List[str], optional): List of columns to use as inputs to the classification
label_column (str, optional): Name of the column which identifies the label of the time series. Defaults to "label".
context_length (int, optional): Length of historical data used when creating individual examples in the torch dataset.
Defaults to 1.
num_workers (int, optional): (Currently not used) Number of workers. Defaults to 1.
stride (int, optional): Stride at which windows are produced. Defaults to 1.
fill_value (Union[float, int], optional): Value used to fill any missing values. Defaults to 0.0.
The resulting dataset returns records (dictionaries) containing:
past_values: tensor of past values of the target columns of length equal to context length (context_length x len(input_columns))
past_observed_mask: tensor indicating which values are observed in the past values tensor (context_length x len(input_columns))
target_values: tensor containing the label (scalar)
static_categorical_features: tensor of static categorical features (1 x len(static_categorical_columns))
timestamp: the timestamp of the end of the context window
id: a tuple of id values (taken from the id columns) containing the id information of the time series segment
"""

def __init__(
Expand Down

0 comments on commit 0263bd5

Please sign in to comment.