diff --git a/notebooks/recipes/energy_demand_forecasting/demand_forecast_zeroshot_recipe_minimal.ipynb b/notebooks/recipes/energy_demand_forecasting/demand_forecast_zeroshot_recipe_minimal.ipynb index 0cd46783..dd1bf605 100644 --- a/notebooks/recipes/energy_demand_forecasting/demand_forecast_zeroshot_recipe_minimal.ipynb +++ b/notebooks/recipes/energy_demand_forecasting/demand_forecast_zeroshot_recipe_minimal.ipynb @@ -33,7 +33,9 @@ ], "source": [ "import pathlib\n", + "\n", "import pandas as pd\n", + "\n", "from tsfm_public import TimeSeriesForecastingPipeline, TinyTimeMixerForPrediction\n", "from tsfm_public.toolkit.visualization import plot_predictions" ] @@ -57,6 +59,8 @@ ], "source": [ "import tsfm_public\n", + "\n", + "\n", "tsfm_public.__version__" ] }, @@ -490,7 +494,12 @@ ], "source": [ "pipeline = TimeSeriesForecastingPipeline(\n", - " zeroshot_model, timestamp_column=timestamp_column, target_columns=[target_column], explode_forecasts=True, freq=\"h\", id_columns=[]\n", + " zeroshot_model,\n", + " timestamp_column=timestamp_column,\n", + " target_columns=[target_column],\n", + " explode_forecasts=True,\n", + " freq=\"h\",\n", + " id_columns=[],\n", ")\n", "zeroshot_forecast = pipeline(data)\n", "zeroshot_forecast.head()" @@ -529,7 +538,7 @@ " timestamp_column=timestamp_column,\n", " channel=target_column,\n", " indices=[-1],\n", - " num_plots=1\n", + " num_plots=1,\n", ")" ] }, diff --git a/tsfm_public/toolkit/visualization.py b/tsfm_public/toolkit/visualization.py index d47ea0ee..eb52d6a4 100644 --- a/tsfm_public/toolkit/visualization.py +++ b/tsfm_public/toolkit/visualization.py @@ -286,7 +286,9 @@ def plot_predictions( using_pipeline = False plot_test_data = True else: - raise RuntimeError("You must provide either input_df and predictions_df, or dset and model, or input_df and exploded_predictions_df.") + raise RuntimeError( + "You must provide either input_df and predictions_df, or dset and model, or input_df and exploded_predictions_df." + ) if plot_context is None: plot_context = 2 * prediction_length @@ -317,9 +319,11 @@ def plot_predictions( y = y.values border = ts_y[-prediction_length] plot_title = f"Example {indices[i]}" - + elif using_pipeline: - ts_y_hat = create_timestamps(exploded_predictions_df[timestamp_column].iloc[0], freq=freq, periods=prediction_length) + ts_y_hat = create_timestamps( + exploded_predictions_df[timestamp_column].iloc[0], freq=freq, periods=prediction_length + ) y_hat = exploded_predictions_df[f"{channel}_prediction"] # get context