From c0716a295decce22c6596d76645db3e9933debc0 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Wed, 4 Sep 2024 13:50:57 -0400 Subject: [PATCH 1/2] improve docstring, fix boundary location --- tsfm_public/toolkit/visualization.py | 41 ++++++++++++++++------------ 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/tsfm_public/toolkit/visualization.py b/tsfm_public/toolkit/visualization.py index eb52d6a4..3c08c9e9 100644 --- a/tsfm_public/toolkit/visualization.py +++ b/tsfm_public/toolkit/visualization.py @@ -213,7 +213,7 @@ def plot_predictions( freq: Optional[str] = None, timestamp_column: Optional[str] = None, id_columns: Optional[List[str]] = None, - plot_context: Optional[None] = None, + plot_context: Optional[int] = None, plot_dir: str = None, num_plots: int = 10, plot_prefix: str = "valid", @@ -222,22 +222,28 @@ def plot_predictions( ): """Utility for plotting forecasts along with context and test data. + User should pass either: + + - input_df and predictions_df: context will be extracted from input_df, and predictions will be extracted from predictions_df. Predictions_df is expected to have rows containing lists of predictions. + - input_df and exploded_predictions_df: context will be extracted from input_df, and predictions from exploded_predictions_df will be plotted + - dset and model: model will be used to produce predictions from records selected from dset + + Args: - input_df: The input dataframe from which the predictions are generated, containing timestamp and target columns. - predictions_df: The predictions dataframe, where each row contains starting timestamp and a list of predictions for each target column. - exploded_predictions_df: The predictions dataframe, containing timestamp and predicted target columns. - dset: Dataset. - context_df: Context dataframe, containing timestamp and target columns. - model: The pre-trained TimeseriesModel. - freq: Frequency of the time series data - timestamp_column: Name of timestamp column in the dataframe. - id_columns: List of id columns in the dataframe. - plot_context: If True, plot context data along with forecasts. - plot_dir: Directory where plots are saved. - num_plots: Number of subplots to plot in the figure. - plot_prefix: Prefix to put on the plot file names. - channel: Channel (target column or its index) to plot. - indices: List of indices to plot. + input_df (Optional[pd.DataFrame], optional): The input dataframe from which the predictions are generated, containing timestamp and target columns. Defaults to None. + predictions_df (Optional[pd.DataFrame], optional): The predictions dataframe, where each row contains starting timestamp and a list of predictions for each target column. Defaults to None. + exploded_predictions_df (Optional[pd.DataFrame], optional): The predictions dataframe, containing timestamp and predicted target columns. Defaults to None. + dset (Optional[Dataset], optional): Torch dataset containing the context data to use as input for the model. Defaults to None. + model (Optional[PreTrainedModel], optional): The pre-trained time series model. Defaults to None. + freq (Optional[str], optional): Frequency of the time series data, using Pandas string abbreviations (https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases). Defaults to None. + timestamp_column (Optional[str], optional): Name of timestamp column in the dataframe. Defaults to None. + id_columns (Optional[List[str]], optional): (For future use) List of id columns in the dataframe. Defaults to None. + plot_context (Optional[int], optional): Integer representing the number of time points of historical data to plot. Defaults to None. + plot_dir (str, optional): Directory where plots are saved. Defaults to None. + num_plots (int, optional): Number of subplots to plot in the figure. Defaults to 10. + plot_prefix (str, optional): Prefix to put on the plot file names. Defaults to "valid". + channel (Union[int, str], optional): Channel, i.e., target column or its index, to plot. Defaults to None. + indices (List[int], optional): List of indices to plot. If None, random examples will be chosen. Defaults to None. """ if indices is not None: num_plots = len(indices) @@ -317,7 +323,8 @@ def plot_predictions( y = gt_df.loc[ts_index][channel] ts_y = y.index y = y.values - border = ts_y[-prediction_length] + # border = ts_y[-prediction_length] + border = predictions_subset[i][timestamp_column] plot_title = f"Example {indices[i]}" elif using_pipeline: From 74fdcec1a594c9aaef54c9a7f069db9fdf2c11cf Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Mon, 9 Sep 2024 09:17:30 -0400 Subject: [PATCH 2/2] docstring updates --- tsfm_public/toolkit/visualization.py | 33 +++++++++++++++++++--------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/tsfm_public/toolkit/visualization.py b/tsfm_public/toolkit/visualization.py index 3c08c9e9..ac67878e 100644 --- a/tsfm_public/toolkit/visualization.py +++ b/tsfm_public/toolkit/visualization.py @@ -224,26 +224,38 @@ def plot_predictions( User should pass either: - - input_df and predictions_df: context will be extracted from input_df, and predictions will be extracted from predictions_df. Predictions_df is expected to have rows containing lists of predictions. - - input_df and exploded_predictions_df: context will be extracted from input_df, and predictions from exploded_predictions_df will be plotted + - input_df and predictions_df: context will be extracted from input_df, and predictions will be extracted from + predictions_df. Predictions_df is expected to have rows containing lists of predictions. + - input_df and exploded_predictions_df: context will be extracted from input_df, and predictions from + exploded_predictions_df will be plotted - dset and model: model will be used to produce predictions from records selected from dset + If exploded_predictions_df is passed, indices and num_plots are ignored, the assumption is that there are only one + set of predictions passed for plotting. Args: - input_df (Optional[pd.DataFrame], optional): The input dataframe from which the predictions are generated, containing timestamp and target columns. Defaults to None. - predictions_df (Optional[pd.DataFrame], optional): The predictions dataframe, where each row contains starting timestamp and a list of predictions for each target column. Defaults to None. - exploded_predictions_df (Optional[pd.DataFrame], optional): The predictions dataframe, containing timestamp and predicted target columns. Defaults to None. - dset (Optional[Dataset], optional): Torch dataset containing the context data to use as input for the model. Defaults to None. + input_df (Optional[pd.DataFrame], optional): The input dataframe from which the predictions are generated, + containing timestamp and target columns. Defaults to None. + predictions_df (Optional[pd.DataFrame], optional): The predictions dataframe, where each row contains starting + timestamp and a list of predictions for each target column. Defaults to None. + exploded_predictions_df (Optional[pd.DataFrame], optional): The predictions dataframe, containing timestamp + and predicted target columns. Defaults to None. + dset (Optional[Dataset], optional): Torch dataset containing the context data to use as input for the model. + Defaults to None. model (Optional[PreTrainedModel], optional): The pre-trained time series model. Defaults to None. - freq (Optional[str], optional): Frequency of the time series data, using Pandas string abbreviations (https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases). Defaults to None. + freq (Optional[str], optional): Frequency of the time series data, using Pandas string abbreviations + (https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases). Defaults to None. timestamp_column (Optional[str], optional): Name of timestamp column in the dataframe. Defaults to None. - id_columns (Optional[List[str]], optional): (For future use) List of id columns in the dataframe. Defaults to None. - plot_context (Optional[int], optional): Integer representing the number of time points of historical data to plot. Defaults to None. + id_columns (Optional[List[str]], optional): (For future use) List of id columns in the dataframe. Defaults to + None. + plot_context (Optional[int], optional): Integer representing the number of time points of historical data to + plot. Defaults to None. plot_dir (str, optional): Directory where plots are saved. Defaults to None. num_plots (int, optional): Number of subplots to plot in the figure. Defaults to 10. plot_prefix (str, optional): Prefix to put on the plot file names. Defaults to "valid". channel (Union[int, str], optional): Channel, i.e., target column or its index, to plot. Defaults to None. - indices (List[int], optional): List of indices to plot. If None, random examples will be chosen. Defaults to None. + indices (List[int], optional): List of indices to plot. If None, random examples will be chosen. Defaults to + None. """ if indices is not None: num_plots = len(indices) @@ -263,6 +275,7 @@ def plot_predictions( plot_context = len(input_df) using_pipeline = True plot_test_data = False + indices = [-1] # indices not used in exploded case elif input_df is not None and predictions_df is not None: # 2) input_df and predictions plus column information is provided