switch to interval_[start/end] + allow forecast non-zero first step (#64)

* switch to interval_[start/end] + allow forecast to have non-zero first step time

* fix test

* remove unneeded param check
dfulu authored Nov 14, 2024
1 parent cc5d43c commit f9fd827
Showing 16 changed files with 250 additions and 291 deletions.
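For orientation, the renamed config fields relate to the old ones as sketched below. This is an illustrative helper, not part of the commit, and the function name is hypothetical; per the field descriptions in the diff below, the interval is expressed relative to t0.

def to_interval_fields(history_minutes: int, forecast_minutes: int) -> dict:
    # A 60-minute history becomes interval_start_minutes=-60;
    # a 120-minute forecast becomes interval_end_minutes=120.
    return {
        "interval_start_minutes": -history_minutes,
        "interval_end_minutes": forecast_minutes,
    }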
92 changes: 46 additions & 46 deletions ocf_data_sampler/config/model.py
@@ -14,7 +14,8 @@
from typing import Dict, List, Optional
from typing_extensions import Self

from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
from pydantic import BaseModel, Field, RootModel, field_validator, ValidationInfo, model_validator

from ocf_data_sampler.constants import NWP_PROVIDERS

logger = logging.getLogger(__name__)
@@ -40,6 +41,45 @@ class General(Base):
)


class TimeWindowMixin(Base):
"""Mixin class, to add interval start, end and resolution minutes"""

time_resolution_minutes: int = Field(
...,
gt=0,
description="The temporal resolution of the data in minutes",
)

interval_start_minutes: int = Field(
...,
description="Data interval starts at `t0 + interval_start_minutes`",
)

interval_end_minutes: int = Field(
...,
description="Data interval ends at `t0 + interval_end_minutes`",
)

@model_validator(mode='after')
def check_interval_range(cls, values):
if values.interval_start_minutes > values.interval_end_minutes:
raise ValueError('interval_start_minutes must be <= interval_end_minutes')
return values

@field_validator("interval_start_minutes")
def interval_start_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
if v % info.data["time_resolution_minutes"] != 0:
raise ValueError("interval_start_minutes must be divisible by time_resolution_minutes")
return v

@field_validator("interval_end_minutes")
def interval_end_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
if v % info.data["time_resolution_minutes"] != 0:
raise ValueError("interval_end_minutes must be divisible by time_resolution_minutes")
return v
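A minimal sketch of how these validators behave, assuming pydantic v2 and instantiating the mixin directly purely for illustration:

from pydantic import ValidationError

window = TimeWindowMixin(
    time_resolution_minutes=30,
    interval_start_minutes=-60,  # one hour of history before t0
    interval_end_minutes=120,    # two hours of forecast after t0
)

try:
    TimeWindowMixin(
        time_resolution_minutes=30,
        interval_start_minutes=-45,  # not divisible by 30, so rejected
        interval_end_minutes=120,
    )
except ValidationError as err:
    print(err)  # interval_start_minutes must be divisible by time_resolution_minutes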



# noinspection PyMethodParameters
class DropoutMixin(Base):
"""Mixin class, to add dropout minutes"""
@@ -76,54 +116,18 @@ def dropout_instructions_consistent(self) -> Self:
return self


# noinspection PyMethodParameters
class TimeWindowMixin(Base):
"""Time resolution mix in"""

time_resolution_minutes: int = Field(
...,
gt=0,
description="The temporal resolution of the data in minutes",
)

forecast_minutes: int = Field(
...,
ge=0,
description="how many minutes to forecast in the future",
)
history_minutes: int = Field(
...,
ge=0,
description="how many historic minutes to use",
)

@field_validator("forecast_minutes")
def forecast_minutes_divide_by_time_resolution(cls, v, values) -> int:
if v % values.data["time_resolution_minutes"] != 0:
message = "Forecast duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v

@field_validator("history_minutes")
def history_minutes_divide_by_time_resolution(cls, v, values) -> int:
if v % values.data["time_resolution_minutes"] != 0:
message = "History duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v


class SpatialWindowMixin(Base):
"""Mixin class, to add path and image size"""

image_size_pixels_height: int = Field(
...,
ge=0,
description="The number of pixels of the height of the region of interest",
)

image_size_pixels_width: int = Field(
...,
ge=0,
description="The number of pixels of the width of the region of interest",
)

@@ -140,10 +144,6 @@ class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
..., description="the satellite channels that are used"
)

live_delay_minutes: int = Field(
..., description="The expected delay in minutes of the satellite data"
)


# noinspection PyMethodParameters
class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
@@ -169,6 +169,7 @@ class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
" the maximum forecast horizon of the NWP and the requested forecast length.",
)


@field_validator("provider")
def validate_provider(cls, v: str) -> str:
"""Validate 'provider'"""
@@ -227,11 +228,10 @@ class Site(TimeWindowMixin, DropoutMixin):
# TODO validate the csv for metadata



# noinspection PyPep8Naming
class InputData(Base):
"""
Input data model.
"""
"""Input data model"""

satellite: Optional[Satellite] = None
nwp: Optional[MultiNWP] = None
115 changes: 40 additions & 75 deletions ocf_data_sampler/select/find_contiguous_time_periods.py
@@ -63,25 +63,25 @@

def trim_contiguous_time_periods(
contiguous_time_periods: pd.DataFrame,
history_duration: pd.Timedelta,
forecast_duration: pd.Timedelta,
interval_start: pd.Timedelta,
interval_end: pd.Timedelta,
) -> pd.DataFrame:
"""Trim the contiguous time periods to allow for history and forecast durations.
Args:
contiguous_time_periods: DataFrame where each row represents a single time period. The
DataFrame must have `start_dt` and `end_dt` columns.
history_duration: Length of the historical slice used for a sample
forecast_duration: Length of the forecast slice used for a sample
interval_start: The start of the interval with respect to t0
interval_end: The end of the interval with respect to t0
Returns:
The contiguous_time_periods DataFrame with the `start_dt` and `end_dt` columns updated.
"""
contiguous_time_periods = contiguous_time_periods.copy()

contiguous_time_periods["start_dt"] += history_duration
contiguous_time_periods["end_dt"] -= forecast_duration
contiguous_time_periods["start_dt"] -= interval_start
contiguous_time_periods["end_dt"] -= interval_end

valid_mask = contiguous_time_periods["start_dt"] <= contiguous_time_periods["end_dt"]
contiguous_time_periods = contiguous_time_periods.loc[valid_mask]
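A worked example of the trimming above (a sketch; timestamps chosen for illustration):

import pandas as pd

periods = pd.DataFrame({
    "start_dt": [pd.Timestamp("2024-01-01 00:00")],
    "end_dt": [pd.Timestamp("2024-01-01 12:00")],
})

trimmed = trim_contiguous_time_periods(
    contiguous_time_periods=periods,
    interval_start=pd.Timedelta("-60min"),  # samples reach back 1h before t0
    interval_end=pd.Timedelta("120min"),    # and forward 2h after t0
)
# start_dt becomes 01:00 (00:00 minus -60min) and end_dt becomes 10:00
# (12:00 minus 120min), so every t0 in [01:00, 10:00] fits a full window.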
@@ -92,24 +92,24 @@ def trim_contiguous_time_periods(

def find_contiguous_t0_periods(
datetimes: pd.DatetimeIndex,
history_duration: pd.Timedelta,
forecast_duration: pd.Timedelta,
interval_start: pd.Timedelta,
interval_end: pd.Timedelta,
sample_period_duration: pd.Timedelta,
) -> pd.DataFrame:
"""Return a pd.DataFrame where each row records the boundary of a contiguous time period.
Args:
datetimes: pd.DatetimeIndex. Must be sorted.
history_duration: Length of the historical slice used for each sample
forecast_duration: Length of the forecast slice used for each sample
interval_start: The start of the interval with respect to t0
interval_end: The end of the interval with respect to t0
sample_period_duration: The sample frequency of the timeseries
Returns:
pd.DataFrame where each row represents a single time period. The pd.DataFrame
has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
"""
total_duration = history_duration + forecast_duration
total_duration = interval_end - interval_start

contiguous_time_periods = find_contiguous_time_periods(
datetimes=datetimes,
@@ -119,101 +119,66 @@

contiguous_t0_periods = trim_contiguous_time_periods(
contiguous_time_periods=contiguous_time_periods,
history_duration=history_duration,
forecast_duration=forecast_duration,
interval_start=interval_start,
interval_end=interval_end,
)

assert len(contiguous_t0_periods) > 0

return contiguous_t0_periods
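An illustrative call, assuming 5-minutely data; note that total_duration = interval_end - interval_start reduces to the old history + forecast length when interval_start = -history:

import pandas as pd

datetimes = pd.date_range("2024-01-01 00:00", "2024-01-01 06:00", freq="5min")

t0_periods = find_contiguous_t0_periods(
    datetimes=datetimes,
    interval_start=pd.Timedelta("-60min"),
    interval_end=pd.Timedelta("120min"),
    sample_period_duration=pd.Timedelta("5min"),
)
# Valid t0 values are those whose whole [t0-60min, t0+120min] window
# lies inside a single contiguous run of datetimes.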


def _find_contiguous_t0_periods_nwp(
ds,
history_duration: pd.Timedelta,
forecast_duration: pd.Timedelta,
max_staleness: pd.Timedelta | None = None,
max_dropout: pd.Timedelta = pd.Timedelta(0),
time_dim: str = "init_time_utc",
end_buffer: pd.Timedelta = pd.Timedelta(0),
):

assert "step" in ds.coords
# It is possible to use up to this amount of max staleness for the dataset and slice
# required
possible_max_staleness = (
pd.Timedelta(ds["step"].max().item())
- forecast_duration
- end_buffer
)

# If max_staleness is set to None we set it based on the max step ahead of the input
# forecast data
if max_staleness is None:
max_staleness = possible_max_staleness
else:
# Make sure the max acceptable staleness isn't longer than the max possible
assert max_staleness <= possible_max_staleness
max_staleness = max_staleness

contiguous_time_periods = find_contiguous_t0_periods_nwp(
datetimes=pd.DatetimeIndex(ds[time_dim]),
history_duration=history_duration,
max_staleness=max_staleness,
max_dropout=max_dropout,
)
return contiguous_time_periods



def find_contiguous_t0_periods_nwp(
datetimes: pd.DatetimeIndex,
history_duration: pd.Timedelta,
init_times: pd.DatetimeIndex,
interval_start: pd.Timedelta,
max_staleness: pd.Timedelta,
max_dropout: pd.Timedelta = pd.Timedelta(0),
first_forecast_step: pd.Timedelta = pd.Timedelta(0),

) -> pd.DataFrame:
"""Get all time periods from the NWP init times which are valid as t0 datetimes.
Args:
datetimes: Sorted pd.DatetimeIndex
history_duration: Length of the historical slice used for a sample
max_staleness: Up to how long after an NWP forecast init_time are we willing to use the
forecast. Each init time will only be used up to this t0 time regardless of the forecast
valid time.
init_times: The initialisation times of the available forecasts
interval_start: The start of the desired data interval with respect to t0
max_staleness: Up to how long after an init time are we willing to use the forecast. Each
init time will only be used up to this t0 time regardless of the forecast valid time.
max_dropout: What is the maximum amount of dropout that will be used. This must be <=
max_staleness.
first_forecast_step: The timedelta of the first step of the forecast. By default we assume
the first valid time of the forecast is the same as its init time.
Returns:
pd.DataFrame where each row represents a single time period. The pd.DataFrame
has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
"""
# Sanity checks.
assert len(datetimes) > 0
assert datetimes.is_monotonic_increasing
assert datetimes.is_unique
assert history_duration >= pd.Timedelta(0)
assert len(init_times) > 0
assert init_times.is_monotonic_increasing
assert init_times.is_unique
assert max_staleness >= pd.Timedelta(0)
assert max_dropout <= max_staleness
assert pd.Timedelta(0) <= max_dropout <= max_staleness

hist_drop_buffer = max(history_duration, max_dropout)
hist_drop_buffer = max(first_forecast_step-interval_start, max_dropout)

# Store contiguous periods
contiguous_periods = []

# Start first period allowing for history slice and max dropout
start_this_period = datetimes[0] + hist_drop_buffer
# Begin the first period allowing for the time to the first_forecast_step, the length of the
# interval sampled from before t0, and the dropout
start_this_period = init_times[0] + hist_drop_buffer

# The first forecast is valid up to the max staleness
end_this_period = datetimes[0] + max_staleness

for dt_init in datetimes[1:]:
# If the previous init time becomes stale before the next init becomes valid whilst also
# considering dropout - then the contiguous period breaks, and new starts with considering
# dropout and history duration
if end_this_period < dt_init + max_dropout:
end_this_period = init_times[0] + max_staleness

for dt_init in init_times[1:]:
# If the previous init time becomes stale before the next init becomes valid (whilst also
# considering dropout) then the contiguous period breaks
        # Else if the previous init time becomes stale before the first step of the next forecast
# then this also causes a break in the contiguous period
if (end_this_period < dt_init + max(max_dropout, first_forecast_step)):
contiguous_periods.append([start_this_period, end_this_period])

# And start a new period
# The new period begins with the same conditions as the first period
start_this_period = dt_init + hist_drop_buffer
end_this_period = dt_init + max_staleness
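A sketch of the new behaviour with a non-zero first step (values are illustrative):

import pandas as pd

# Forecasts initialised every 6 hours whose first valid time is 1 hour
# after the init time, the case this commit newly supports.
init_times = pd.date_range("2024-01-01", periods=8, freq="6h")

nwp_t0_periods = find_contiguous_t0_periods_nwp(
    init_times=init_times,
    interval_start=pd.Timedelta("-60min"),
    max_staleness=pd.Timedelta("9h"),
    max_dropout=pd.Timedelta("30min"),
    first_forecast_step=pd.Timedelta("60min"),
)
# hist_drop_buffer = max(60min - (-60min), 30min) = 2h, so the first
# valid t0 falls 2h after the first init time.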

