Skip to content

Commit

Permalink
Run make style for formatting.
Browse files Browse the repository at this point in the history
Signed-off-by: Fayvor Love <[email protected]>
  • Loading branch information
fayvor committed Aug 30, 2024
1 parent 85c46dd commit 7a58c05
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@
],
"source": [
"import pathlib\n",
"\n",
"import pandas as pd\n",
"\n",
"from tsfm_public import TimeSeriesForecastingPipeline, TinyTimeMixerForPrediction\n",
"from tsfm_public.toolkit.visualization import plot_predictions"
]
Expand All @@ -57,6 +59,8 @@
],
"source": [
"import tsfm_public\n",
"\n",
"\n",
"tsfm_public.__version__"
]
},
Expand Down Expand Up @@ -490,7 +494,12 @@
],
"source": [
"pipeline = TimeSeriesForecastingPipeline(\n",
" zeroshot_model, timestamp_column=timestamp_column, target_columns=[target_column], explode_forecasts=True, freq=\"h\", id_columns=[]\n",
" zeroshot_model,\n",
" timestamp_column=timestamp_column,\n",
" target_columns=[target_column],\n",
" explode_forecasts=True,\n",
" freq=\"h\",\n",
" id_columns=[],\n",
")\n",
"zeroshot_forecast = pipeline(data)\n",
"zeroshot_forecast.head()"
Expand Down Expand Up @@ -529,7 +538,7 @@
" timestamp_column=timestamp_column,\n",
" channel=target_column,\n",
" indices=[-1],\n",
" num_plots=1\n",
" num_plots=1,\n",
")"
]
},
Expand Down
10 changes: 7 additions & 3 deletions tsfm_public/toolkit/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,9 @@ def plot_predictions(
using_pipeline = False
plot_test_data = True
else:
raise RuntimeError("You must provide either input_df and predictions_df, or dset and model, or input_df and exploded_predictions_df.")
raise RuntimeError(
"You must provide either input_df and predictions_df, or dset and model, or input_df and exploded_predictions_df."
)

if plot_context is None:
plot_context = 2 * prediction_length
Expand Down Expand Up @@ -317,9 +319,11 @@ def plot_predictions(
y = y.values
border = ts_y[-prediction_length]
plot_title = f"Example {indices[i]}"

elif using_pipeline:
ts_y_hat = create_timestamps(exploded_predictions_df[timestamp_column].iloc[0], freq=freq, periods=prediction_length)
ts_y_hat = create_timestamps(
exploded_predictions_df[timestamp_column].iloc[0], freq=freq, periods=prediction_length
)
y_hat = exploded_predictions_df[f"{channel}_prediction"]

# get context
Expand Down

0 comments on commit 7a58c05

Please sign in to comment.