Skip to content

Commit

Permalink
Accept both context and test data as input_df.
Browse files (browse the repository at this point in the history)
Signed-off-by: Fayvor Love <[email protected]>
  • Loading branch information
fayvor committed Aug 30, 2024
1 parent 4816ba9 commit 85c46dd
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:p-20565:t-8607625792:config.py:<module>:PyTorch version 2.2.2 available.\n"
"INFO:p-26782:t-8661148224:config.py:<module>:PyTorch version 2.2.2 available.\n"
]
}
],
Expand Down Expand Up @@ -523,8 +523,8 @@
],
"source": [
"plot_predictions(\n",
" predictions_df=zeroshot_forecast,\n",
" context_df=data,\n",
" input_df=data,\n",
" exploded_predictions_df=zeroshot_forecast,\n",
" freq=\"h\",\n",
" timestamp_column=timestamp_column,\n",
" channel=target_column,\n",
Expand Down
35 changes: 18 additions & 17 deletions tsfm_public/toolkit/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,10 +205,10 @@ def plot_ts_forecasting(


def plot_predictions(
test_df: Optional[pd.DataFrame] = None,
input_df: Optional[pd.DataFrame] = None,
predictions_df: Optional[pd.DataFrame] = None,
exploded_predictions_df: Optional[pd.DataFrame] = None,
dset: Optional[Dataset] = None,
context_df: Optional[pd.DataFrame] = None,
model: Optional[PreTrainedModel] = None,
freq: Optional[str] = None,
timestamp_column: Optional[str] = None,
Expand All @@ -223,8 +223,9 @@ def plot_predictions(
"""Utility for plotting forecasts along with context and test data.
Args:
test_df: Test data.
predictions_df: The predictions dataframe, containing timestamp and prediction columns
input_df: The input dataframe from which the predictions are generated, containing timestamp and target columns.
predictions_df: The predictions dataframe, where each row contains starting timestamp and a list of predictions for each target column.
exploded_predictions_df: The predictions dataframe, containing timestamp and predicted target columns.
dset: Dataset.
context_df: Context dataframe, containing timestamp and target columns.
model: The pre-trained TimeseriesModel.
Expand All @@ -242,29 +243,29 @@ def plot_predictions(
num_plots = len(indices)

# possible operations:
if context_df is not None and predictions_df is not None:
if input_df is not None and exploded_predictions_df is not None:
# 1) This is a zero-shot prediction, so no test data. We have context data for the channel (target column).
# We expect the context and predictions to contain the channel
pchannel = f"{channel}_prediction"
if pchannel not in predictions_df.columns:
if pchannel not in exploded_predictions_df.columns:
raise ValueError(f"Predictions dataframe does not contain target column '{pchannel}'.")
if channel not in context_df.columns:
if channel not in input_df.columns:
raise ValueError(f"Context dataframe does not contain target column '{channel}'.")

num_plots = 1
prediction_length = len(predictions_df)
plot_context = len(context_df)
prediction_length = len(exploded_predictions_df)
plot_context = len(input_df)
using_pipeline = True
plot_test_data = False
elif test_df is not None and predictions_df is not None:
# 2) test_df and predictions plus column information is provided
elif input_df is not None and predictions_df is not None:
# 2) input_df and predictions plus column information is provided

if indices is None:
l = len(predictions_df)
indices = np.random.choice(l, size=num_plots, replace=False)
predictions_subset = [predictions_df.iloc[i] for i in indices]

gt_df = test_df.copy()
gt_df = input_df.copy()
gt_df = gt_df.set_index(timestamp_column) # add id column logic here

prediction_length = len(predictions_subset[0][channel])
Expand All @@ -285,7 +286,7 @@ def plot_predictions(
using_pipeline = False
plot_test_data = True
else:
raise RuntimeError("You must provide either test_df and predictions_df, or dset and model, or context_df, predictions_df and target_columns.")
raise RuntimeError("You must provide either input_df and predictions_df, or dset and model, or input_df and exploded_predictions_df.")

if plot_context is None:
plot_context = 2 * prediction_length
Expand Down Expand Up @@ -318,13 +319,13 @@ def plot_predictions(
plot_title = f"Example {indices[i]}"

elif using_pipeline:
ts_y_hat = create_timestamps(predictions_df[timestamp_column].iloc[0], freq=freq, periods=prediction_length)
y_hat = predictions_df[f"{channel}_prediction"]
ts_y_hat = create_timestamps(exploded_predictions_df[timestamp_column].iloc[0], freq=freq, periods=prediction_length)
y_hat = exploded_predictions_df[f"{channel}_prediction"]

# get context
# ts_y = create_timestamps(context_df[timestamp_column].iloc[0], freq=freq, periods=len(context_df))
ts_y = context_df[timestamp_column].values
y = context_df[channel].values
ts_y = input_df[timestamp_column].values
y = input_df[channel].values
border = None
plot_title = f"Forecast for {channel}"

Expand Down

0 comments on commit 85c46dd

Please sign in to comment.