diff --git a/src/database.py b/src/database.py index 51be1c6..973e81f 100644 --- a/src/database.py +++ b/src/database.py @@ -40,7 +40,7 @@ convert_forecasts_to_many_datetime_many_generation, convert_location_sql_to_many_datetime_many_generation, ) -from utils import floor_30_minutes_dt, get_start_datetime +from utils import floor_30_minutes_dt, get_start_datetime, filter_forecast_values logger = structlog.stdlib.get_logger() @@ -122,14 +122,19 @@ def get_forecasts_from_database( else: # To speed up read time we only look at the last 12 hours of results, and take floor 30 mins - yesterday_start_datetime = floor_30_minutes_dt( - datetime.now(tz=timezone.utc) - timedelta(hours=12) - ) + if start_datetime_utc is None: + start_datetime_utc = floor_30_minutes_dt( + datetime.now(tz=timezone.utc) - timedelta(hours=12) + ) + + start_created_utc = floor_30_minutes_dt( + datetime.now(tz=timezone.utc) - timedelta(hours=12) + ) forecasts = get_all_gsp_ids_latest_forecast( session=session, - start_created_utc=yesterday_start_datetime, - start_target_time=yesterday_start_datetime, + start_created_utc=start_created_utc, + start_target_time=start_datetime_utc, preload_children=True, model_name="blend", end_target_time=end_datetime_utc, @@ -137,7 +142,7 @@ def get_forecasts_from_database( if compact: return convert_forecasts_to_many_datetime_many_generation( - forecasts=forecasts, historic=historic + forecasts=forecasts, historic=historic, start_datetime_utc=start_datetime_utc, end_datetime_utc=end_datetime_utc ) else: @@ -147,6 +152,8 @@ def get_forecasts_from_database( else: forecasts = [Forecast.from_orm(forecast) for forecast in forecasts] + forecasts = filter_forecast_values(end_datetime_utc, forecasts, start_datetime_utc) + # return as many forecasts return ManyForecasts(forecasts=forecasts) diff --git a/src/utils.py b/src/utils.py index 4bc737c..41a31d6 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,14 +1,16 @@ """ Utils functions for main.py """ import os from datetime import datetime, timedelta -from typing import Optional, Union +from typing import List, Optional, Union import numpy as np import structlog -from nowcasting_datamodel.models import ForecastValue +from nowcasting_datamodel.models import ForecastValue, ForecastSQL from pydantic import Field from pytz import timezone +from database import logger + logger = structlog.stdlib.get_logger() europe_london_tz = timezone("Europe/London") @@ -71,9 +73,10 @@ def format_datetime(datetime_str: str = None): else: datetime_output = datetime.fromisoformat(datetime_str) if datetime_output.tzinfo is None: - datetime_output = europe_london_tz.localize(datetime_output) + datetime_output = utc.localize(datetime_output) return datetime_output + def get_start_datetime( n_history_days: Optional[Union[str, int]] = None, start_datetime: Optional[datetime] = None ) -> datetime: @@ -92,10 +95,7 @@ def get_start_datetime( now = datetime.now(tz=utc) - if ( - start_datetime is None - or now - start_datetime > timedelta(days=3) - ): + if start_datetime is None or now - start_datetime > timedelta(days=3): if n_history_days is None: n_history_days = os.getenv("N_HISTORY_DAYS", "yesterday") @@ -174,3 +174,40 @@ def format_plevels(national_forecast_value: NationalForecastValue): national_forecast_value.plevels["plevel_90"] = ( national_forecast_value.expected_power_generation_megawatts * 1.2 ) + + +def filter_forecast_values( + forecasts: List[ForecastSQL], + end_datetime_utc: Optional[datetime] = None, + start_datetime_utc: Optional[datetime] = None, +) -> List[ForecastSQL]: + """ + Filter forecast values by start and end datetime + + :param forecasts: list of forecasts + :param end_datetime_utc: start datetime + :param start_datetime_utc: e + :return: + """ + if start_datetime_utc is not None or end_datetime_utc is not None: + logger.info(f"Filtering forecasts from {start_datetime_utc} to {end_datetime_utc}") + forecasts_filtered = [] + for forecast in forecasts: + forecast_values = forecast.forecast_values + if start_datetime_utc is not None: + forecast_values = [ + forecast_value + for forecast_value in forecast_values + if forecast_value.target_time >= start_datetime_utc + ] + if end_datetime_utc is not None: + forecast_values = [ + forecast_value + for forecast_value in forecast_values + if forecast_value.target_time <= end_datetime_utc + ] + forecast.forecast_values = forecast_values + + forecasts_filtered.append(forecast) + forecasts = forecasts_filtered + return forecasts