Skip to content

Commit

Permalink
Merge pull request #130 from ibm-granite/plot_fixes
Browse files Browse the repository at this point in the history
Plot fixes
  • Loading branch information
wgifford authored Sep 10, 2024
2 parents f856e5b + 74fdcec commit fbaa4b5
Showing 1 changed file with 37 additions and 17 deletions.
54 changes: 37 additions & 17 deletions tsfm_public/toolkit/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def plot_predictions(
freq: Optional[str] = None,
timestamp_column: Optional[str] = None,
id_columns: Optional[List[str]] = None,
plot_context: Optional[None] = None,
plot_context: Optional[int] = None,
plot_dir: str = None,
num_plots: int = 10,
plot_prefix: str = "valid",
Expand All @@ -222,22 +222,40 @@ def plot_predictions(
):
"""Utility for plotting forecasts along with context and test data.
User should pass either:
- input_df and predictions_df: context will be extracted from input_df, and predictions will be extracted from
predictions_df. Predictions_df is expected to have rows containing lists of predictions.
- input_df and exploded_predictions_df: context will be extracted from input_df, and predictions from
exploded_predictions_df will be plotted
- dset and model: model will be used to produce predictions from records selected from dset
If exploded_predictions_df is passed, indices and num_plots are ignored, the assumption is that there are only one
set of predictions passed for plotting.
Args:
input_df: The input dataframe from which the predictions are generated, containing timestamp and target columns.
predictions_df: The predictions dataframe, where each row contains starting timestamp and a list of predictions for each target column.
exploded_predictions_df: The predictions dataframe, containing timestamp and predicted target columns.
dset: Dataset.
context_df: Context dataframe, containing timestamp and target columns.
model: The pre-trained TimeseriesModel.
freq: Frequency of the time series data
timestamp_column: Name of timestamp column in the dataframe.
id_columns: List of id columns in the dataframe.
plot_context: If True, plot context data along with forecasts.
plot_dir: Directory where plots are saved.
num_plots: Number of subplots to plot in the figure.
plot_prefix: Prefix to put on the plot file names.
channel: Channel (target column or its index) to plot.
indices: List of indices to plot.
input_df (Optional[pd.DataFrame], optional): The input dataframe from which the predictions are generated,
containing timestamp and target columns. Defaults to None.
predictions_df (Optional[pd.DataFrame], optional): The predictions dataframe, where each row contains starting
timestamp and a list of predictions for each target column. Defaults to None.
exploded_predictions_df (Optional[pd.DataFrame], optional): The predictions dataframe, containing timestamp
and predicted target columns. Defaults to None.
dset (Optional[Dataset], optional): Torch dataset containing the context data to use as input for the model.
Defaults to None.
model (Optional[PreTrainedModel], optional): The pre-trained time series model. Defaults to None.
freq (Optional[str], optional): Frequency of the time series data, using Pandas string abbreviations
(https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases). Defaults to None.
timestamp_column (Optional[str], optional): Name of timestamp column in the dataframe. Defaults to None.
id_columns (Optional[List[str]], optional): (For future use) List of id columns in the dataframe. Defaults to
None.
plot_context (Optional[int], optional): Integer representing the number of time points of historical data to
plot. Defaults to None.
plot_dir (str, optional): Directory where plots are saved. Defaults to None.
num_plots (int, optional): Number of subplots to plot in the figure. Defaults to 10.
plot_prefix (str, optional): Prefix to put on the plot file names. Defaults to "valid".
channel (Union[int, str], optional): Channel, i.e., target column or its index, to plot. Defaults to None.
indices (List[int], optional): List of indices to plot. If None, random examples will be chosen. Defaults to
None.
"""
if indices is not None:
num_plots = len(indices)
Expand All @@ -257,6 +275,7 @@ def plot_predictions(
plot_context = len(input_df)
using_pipeline = True
plot_test_data = False
indices = [-1] # indices not used in exploded case
elif input_df is not None and predictions_df is not None:
# 2) input_df and predictions plus column information is provided

Expand Down Expand Up @@ -317,7 +336,8 @@ def plot_predictions(
y = gt_df.loc[ts_index][channel]
ts_y = y.index
y = y.values
border = ts_y[-prediction_length]
# border = ts_y[-prediction_length]
border = predictions_subset[i][timestamp_column]
plot_title = f"Example {indices[i]}"

elif using_pipeline:
Expand Down

0 comments on commit fbaa4b5

Please sign in to comment.