Skip to content

Commit

Permalink
fix for filtering forecast/all
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdudfield committed Aug 31, 2023
1 parent ac419f6 commit 911e9c1
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 14 deletions.
21 changes: 14 additions & 7 deletions src/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
convert_forecasts_to_many_datetime_many_generation,
convert_location_sql_to_many_datetime_many_generation,
)
from utils import floor_30_minutes_dt, get_start_datetime
from utils import floor_30_minutes_dt, get_start_datetime, filter_forecast_values

logger = structlog.stdlib.get_logger()

Expand Down Expand Up @@ -122,22 +122,27 @@ def get_forecasts_from_database(

else:
# To speed up read time we only look at the last 12 hours of results, and take floor 30 mins
yesterday_start_datetime = floor_30_minutes_dt(
datetime.now(tz=timezone.utc) - timedelta(hours=12)
)
if start_datetime_utc is None:
start_datetime_utc = floor_30_minutes_dt(
datetime.now(tz=timezone.utc) - timedelta(hours=12)
)

start_created_utc = floor_30_minutes_dt(
datetime.now(tz=timezone.utc) - timedelta(hours=12)
)

forecasts = get_all_gsp_ids_latest_forecast(
session=session,
start_created_utc=yesterday_start_datetime,
start_target_time=yesterday_start_datetime,
start_created_utc=start_created_utc,
start_target_time=start_datetime_utc,
preload_children=True,
model_name="blend",
end_target_time=end_datetime_utc,
)

if compact:
return convert_forecasts_to_many_datetime_many_generation(
forecasts=forecasts, historic=historic
forecasts=forecasts, historic=historic, start_datetime_utc=start_datetime_utc, end_datetime_utc=end_datetime_utc
)

else:
Expand All @@ -147,6 +152,8 @@ def get_forecasts_from_database(
else:
forecasts = [Forecast.from_orm(forecast) for forecast in forecasts]

forecasts = filter_forecast_values(end_datetime_utc, forecasts, start_datetime_utc)

# return as many forecasts
return ManyForecasts(forecasts=forecasts)

Expand Down
51 changes: 44 additions & 7 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
""" Utils functions for main.py """
import os
from datetime import datetime, timedelta
from typing import Optional, Union
from typing import List, Optional, Union

import numpy as np
import structlog
from nowcasting_datamodel.models import ForecastValue
from nowcasting_datamodel.models import ForecastValue, ForecastSQL
from pydantic import Field
from pytz import timezone

from database import logger

logger = structlog.stdlib.get_logger()

europe_london_tz = timezone("Europe/London")
Expand Down Expand Up @@ -71,9 +73,10 @@ def format_datetime(datetime_str: str = None):
else:
datetime_output = datetime.fromisoformat(datetime_str)
if datetime_output.tzinfo is None:
datetime_output = europe_london_tz.localize(datetime_output)
datetime_output = utc.localize(datetime_output)
return datetime_output


def get_start_datetime(
n_history_days: Optional[Union[str, int]] = None, start_datetime: Optional[datetime] = None
) -> datetime:
Expand All @@ -92,10 +95,7 @@ def get_start_datetime(

now = datetime.now(tz=utc)

if (
start_datetime is None
or now - start_datetime > timedelta(days=3)
):
if start_datetime is None or now - start_datetime > timedelta(days=3):
if n_history_days is None:
n_history_days = os.getenv("N_HISTORY_DAYS", "yesterday")

Expand Down Expand Up @@ -174,3 +174,40 @@ def format_plevels(national_forecast_value: NationalForecastValue):
national_forecast_value.plevels["plevel_90"] = (
national_forecast_value.expected_power_generation_megawatts * 1.2
)


def filter_forecast_values(
forecasts: List[ForecastSQL],
end_datetime_utc: Optional[datetime] = None,
start_datetime_utc: Optional[datetime] = None,
) -> List[ForecastSQL]:
"""
Filter forecast values by start and end datetime
:param forecasts: list of forecasts
:param end_datetime_utc: start datetime
:param start_datetime_utc: e
:return:
"""
if start_datetime_utc is not None or end_datetime_utc is not None:
logger.info(f"Filtering forecasts from {start_datetime_utc} to {end_datetime_utc}")
forecasts_filtered = []
for forecast in forecasts:
forecast_values = forecast.forecast_values
if start_datetime_utc is not None:
forecast_values = [
forecast_value
for forecast_value in forecast_values
if forecast_value.target_time >= start_datetime_utc
]
if end_datetime_utc is not None:
forecast_values = [
forecast_value
for forecast_value in forecast_values
if forecast_value.target_time <= end_datetime_utc
]
forecast.forecast_values = forecast_values

forecasts_filtered.append(forecast)
forecasts = forecasts_filtered
return forecasts

0 comments on commit 911e9c1

Please sign in to comment.