Skip to content

Commit

Permalink
Fix/datetimes (#283)
Browse files Browse the repository at this point in the history
* format start and end datetimes

* fix for filtering forecast/all

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* lint

* lint

* fix for filtering

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
peterdudfield and pre-commit-ci[bot] authored Aug 31, 2023
1 parent e50273c commit bb68850
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 27 deletions.
26 changes: 19 additions & 7 deletions src/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
convert_forecasts_to_many_datetime_many_generation,
convert_location_sql_to_many_datetime_many_generation,
)
from utils import floor_30_minutes_dt, get_start_datetime
from utils import filter_forecast_values, floor_30_minutes_dt, get_start_datetime

logger = structlog.stdlib.get_logger()

Expand Down Expand Up @@ -122,22 +122,28 @@ def get_forecasts_from_database(

else:
# To speed up read time we only look at the last 12 hours of results, and take floor 30 mins
yesterday_start_datetime = floor_30_minutes_dt(
datetime.now(tz=timezone.utc) - timedelta(hours=12)
)
if start_datetime_utc is None:
start_datetime_utc = floor_30_minutes_dt(
datetime.now(tz=timezone.utc) - timedelta(hours=12)
)

start_created_utc = floor_30_minutes_dt(datetime.now(tz=timezone.utc) - timedelta(hours=12))

forecasts = get_all_gsp_ids_latest_forecast(
session=session,
start_created_utc=yesterday_start_datetime,
start_target_time=yesterday_start_datetime,
start_created_utc=start_created_utc,
start_target_time=start_datetime_utc,
preload_children=True,
model_name="blend",
end_target_time=end_datetime_utc,
)

if compact:
return convert_forecasts_to_many_datetime_many_generation(
forecasts=forecasts, historic=historic
forecasts=forecasts,
historic=historic,
start_datetime_utc=start_datetime_utc,
end_datetime_utc=end_datetime_utc,
)

else:
Expand All @@ -147,6 +153,12 @@ def get_forecasts_from_database(
else:
forecasts = [Forecast.from_orm(forecast) for forecast in forecasts]

forecasts = filter_forecast_values(
end_datetime_utc=end_datetime_utc,
forecasts=forecasts,
start_datetime_utc=start_datetime_utc,
)

# return as many forecasts
return ManyForecasts(forecasts=forecasts)

Expand Down
34 changes: 23 additions & 11 deletions src/gsp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Get GSP boundary data from eso """
from datetime import datetime
from typing import List, Optional, Union

import structlog
Expand All @@ -24,6 +23,7 @@
LocationWithGSPYields,
OneDatetimeManyForecastValues,
)
from utils import format_datetime

GSP_TOTAL = 317

Expand All @@ -47,8 +47,8 @@ def get_all_available_forecasts(
historic: Optional[bool] = True,
session: Session = Depends(get_session),
user: Auth0User = Security(get_user()),
start_datetime_utc: Optional[datetime] = None,
end_datetime_utc: Optional[datetime] = None,
start_datetime_utc: Optional[str] = None,
end_datetime_utc: Optional[str] = None,
compact: Optional[bool] = False,
) -> Union[ManyForecasts, List[OneDatetimeManyForecastValues]]:
"""### Get all forecasts for all GSPs
Expand All @@ -65,12 +65,15 @@ def get_all_available_forecasts(
#### Parameters
- **historic**: boolean that defaults to `true`, returning yesterday's and
today's forecasts for all GSPs
- **start_datetime_utc**: optional start datetime for the query.
- **end_datetime_utc**: optional end datetime for the query.
- **start_datetime_utc**: optional start datetime for the query. e.g '2023-08-12 10:00:00+00:00'
- **end_datetime_utc**: optional end datetime for the query. e.g '2023-08-12 14:00:00+00:00'
"""

logger.info(f"Get forecasts for all gsps. The option is {historic=} for user {user}")

start_datetime_utc = format_datetime(start_datetime_utc)
end_datetime_utc = format_datetime(end_datetime_utc)

forecasts = get_forecasts_from_database(
session=session,
historic=historic,
Expand Down Expand Up @@ -128,8 +131,8 @@ def get_forecasts_for_a_specific_gsp(
session: Session = Depends(get_session),
forecast_horizon_minutes: Optional[int] = None,
user: Auth0User = Security(get_user()),
start_datetime_utc: Optional[datetime] = None,
end_datetime_utc: Optional[datetime] = None,
start_datetime_utc: Optional[str] = None,
end_datetime_utc: Optional[str] = None,
) -> Union[Forecast, List[ForecastValue]]:
"""### Get recent forecast values for a specific GSP
Expand All @@ -155,6 +158,9 @@ def get_forecasts_for_a_specific_gsp(
logger.info(f"Get forecasts for gsp id {gsp_id} forecast of forecast with only values.")
logger.info(f"This is for user {user}")

start_datetime_utc = format_datetime(start_datetime_utc)
end_datetime_utc = format_datetime(end_datetime_utc)

if gsp_id > GSP_TOTAL:
return Response(None, status.HTTP_204_NO_CONTENT)

Expand Down Expand Up @@ -183,8 +189,8 @@ def get_truths_for_all_gsps(
regime: Optional[str] = None,
session: Session = Depends(get_session),
user: Auth0User = Security(get_user()),
start_datetime_utc: Optional[datetime] = None,
end_datetime_utc: Optional[datetime] = None,
start_datetime_utc: Optional[str] = None,
end_datetime_utc: Optional[str] = None,
compact: Optional[bool] = False,
) -> Union[List[LocationWithGSPYields], List[GSPYieldGroupByDatetime]]:
"""### Get PV_Live values for all GSPs for yesterday and today
Expand All @@ -207,6 +213,9 @@ def get_truths_for_all_gsps(
"""
logger.info(f"Get PV Live estimates values for all gsp id and regime {regime} for user {user}")

start_datetime_utc = format_datetime(start_datetime_utc)
end_datetime_utc = format_datetime(end_datetime_utc)

return get_truth_values_for_all_gsps_from_database(
session=session,
regime=regime,
Expand Down Expand Up @@ -253,8 +262,8 @@ def get_truths_for_a_specific_gsp(
request: Request,
gsp_id: int,
regime: Optional[str] = None,
start_datetime_utc: Optional[datetime] = None,
end_datetime_utc: Optional[datetime] = None,
start_datetime_utc: Optional[str] = None,
end_datetime_utc: Optional[str] = None,
session: Session = Depends(get_session),
user: Auth0User = Security(get_user()),
) -> List[GSPYield]:
Expand All @@ -280,6 +289,9 @@ def get_truths_for_a_specific_gsp(
f"Get PV Live estimates values for gsp id {gsp_id} " f"and regime {regime} for user {user}"
)

start_datetime_utc = format_datetime(start_datetime_utc)
end_datetime_utc = format_datetime(end_datetime_utc)

if gsp_id > GSP_TOTAL:
return Response(None, status.HTTP_204_NO_CONTENT)

Expand Down
9 changes: 8 additions & 1 deletion src/pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,10 @@ def convert_location_sql_to_many_datetime_many_generation(


def convert_forecasts_to_many_datetime_many_generation(
forecasts: List[ForecastSQL], historic: bool = True
forecasts: List[ForecastSQL],
historic: bool = True,
start_datetime_utc: Optional[datetime] = None,
end_datetime_utc: Optional[datetime] = None,
) -> List[OneDatetimeManyForecastValues]:
"""Change forecasts to list of OneDatetimeManyForecastValues
Expand All @@ -134,6 +137,10 @@ def convert_forecasts_to_many_datetime_many_generation(

for forecast_value in forecast_values:
datetime_utc = forecast_value.target_time
if start_datetime_utc is not None and datetime_utc < start_datetime_utc:
continue
if end_datetime_utc is not None and datetime_utc > end_datetime_utc:
continue
forecast_mw = round(forecast_value.expected_power_generation_megawatts, 2)

# if the datetime object is not in the dictionary, add it
Expand Down
2 changes: 1 addition & 1 deletion src/tests/test_gsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_read_latest_all_gsp(db_session, api_client):

r = ManyForecasts(**response.json())
assert len(r.forecasts) == 10
assert len(r.forecasts[0].forecast_values) == 112
assert len(r.forecasts[0].forecast_values) == 40


def test_read_latest_gsp_id_greater_than_total(db_session, api_client):
Expand Down
67 changes: 60 additions & 7 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
""" Utils functions for main.py """
import os
from datetime import datetime, timedelta
from typing import Optional, Union
from typing import List, Optional, Union

import numpy as np
import structlog
from nowcasting_datamodel.models import ForecastValue
from nowcasting_datamodel.models import Forecast, ForecastValue
from pydantic import Field
from pytz import timezone

Expand Down Expand Up @@ -64,6 +64,24 @@ def floor_6_hours_dt(dt: datetime):
return dt


def format_datetime(datetime_str: str = None):
"""
Format datetime string to datetime object
If None return None, if not timezone, add UTC
:param datetime_str:
:return:
"""
if datetime_str is None:
return None

else:
datetime_output = datetime.fromisoformat(datetime_str)
if datetime_output.tzinfo is None:
datetime_output = utc.localize(datetime_output)
return datetime_output


def get_start_datetime(
n_history_days: Optional[Union[str, int]] = None, start_datetime: Optional[datetime] = None
) -> datetime:
Expand All @@ -80,11 +98,9 @@ def get_start_datetime(
:return: start datetime
"""

if (
start_datetime is None
or start_datetime >= datetime.now(tz=timezone.utc)
or datetime.now(tz=timezone.utc) - start_datetime > timedelta(days=3)
):
now = datetime.now(tz=utc)

if start_datetime is None or now - start_datetime > timedelta(days=3):
if n_history_days is None:
n_history_days = os.getenv("N_HISTORY_DAYS", "yesterday")

Expand Down Expand Up @@ -163,3 +179,40 @@ def format_plevels(national_forecast_value: NationalForecastValue):
national_forecast_value.plevels["plevel_90"] = (
national_forecast_value.expected_power_generation_megawatts * 1.2
)


def filter_forecast_values(
forecasts: List[Forecast],
end_datetime_utc: Optional[datetime] = None,
start_datetime_utc: Optional[datetime] = None,
) -> List[Forecast]:
"""
Filter forecast values by start and end datetime
:param forecasts: list of forecasts
:param end_datetime_utc: start datetime
:param start_datetime_utc: e
:return:
"""
if start_datetime_utc is not None or end_datetime_utc is not None:
logger.info(f"Filtering forecasts from {start_datetime_utc} to {end_datetime_utc}")
forecasts_filtered = []
for forecast in forecasts:
forecast_values = forecast.forecast_values
if start_datetime_utc is not None:
forecast_values = [
forecast_value
for forecast_value in forecast_values
if forecast_value.target_time >= start_datetime_utc
]
if end_datetime_utc is not None:
forecast_values = [
forecast_value
for forecast_value in forecast_values
if forecast_value.target_time <= end_datetime_utc
]
forecast.forecast_values = forecast_values

forecasts_filtered.append(forecast)
forecasts = forecasts_filtered
return forecasts

0 comments on commit bb68850

Please sign in to comment.