diff --git a/src/database.py b/src/database.py index c27c2439..7ef3b37a 100644 --- a/src/database.py +++ b/src/database.py @@ -153,7 +153,11 @@ def get_forecasts_from_database( else: forecasts = [Forecast.from_orm(forecast) for forecast in forecasts] - forecasts = filter_forecast_values(end_datetime_utc, forecasts, start_datetime_utc) + forecasts = filter_forecast_values( + end_datetime_utc=end_datetime_utc, + forecasts=forecasts, + start_datetime_utc=start_datetime_utc, + ) # return as many forecasts return ManyForecasts(forecasts=forecasts) diff --git a/src/tests/test_gsp.py b/src/tests/test_gsp.py index 9a924c27..c01a8460 100644 --- a/src/tests/test_gsp.py +++ b/src/tests/test_gsp.py @@ -60,7 +60,7 @@ def test_read_latest_all_gsp(db_session, api_client): r = ManyForecasts(**response.json()) assert len(r.forecasts) == 10 - assert len(r.forecasts[0].forecast_values) == 112 + assert len(r.forecasts[0].forecast_values) == 40 def test_read_latest_gsp_id_greater_than_total(db_session, api_client): diff --git a/src/utils.py b/src/utils.py index 15908713..fc0d60c5 100644 --- a/src/utils.py +++ b/src/utils.py @@ -5,7 +5,7 @@ import numpy as np import structlog -from nowcasting_datamodel.models import ForecastSQL, ForecastValue +from nowcasting_datamodel.models import Forecast, ForecastValue from pydantic import Field from pytz import timezone @@ -182,10 +182,10 @@ def format_plevels(national_forecast_value: NationalForecastValue): def filter_forecast_values( - forecasts: List[ForecastSQL], + forecasts: List[Forecast], end_datetime_utc: Optional[datetime] = None, start_datetime_utc: Optional[datetime] = None, -) -> List[ForecastSQL]: +) -> List[Forecast]: """ Filter forecast values by start and end datetime