switch to interval_[start/end] + allow forecast non-zero first step #64

Merged · 5 commits · Nov 14, 2024
92 changes: 46 additions & 46 deletions ocf_data_sampler/config/model.py
@@ -14,7 +14,8 @@
from typing import Dict, List, Optional
from typing_extensions import Self

from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
from pydantic import BaseModel, Field, RootModel, field_validator, ValidationInfo, model_validator

from ocf_data_sampler.constants import NWP_PROVIDERS

logger = logging.getLogger(__name__)
@@ -40,6 +41,45 @@ class General(Base):
)


class TimeWindowMixin(Base):
"""Mixin class, to add interval start, end and resolution minutes"""

time_resolution_minutes: int = Field(
...,
gt=0,
description="The temporal resolution of the data in minutes",
)

interval_start_minutes: int = Field(
...,
description="Data interval starts at `t0 + interval_start_minutes`",
)

interval_end_minutes: int = Field(
...,
description="Data interval ends at `t0 + interval_end_minutes`",
)

@model_validator(mode='after')
def check_interval_range(cls, values):
if values.interval_start_minutes > values.interval_end_minutes:
raise ValueError('interval_start_minutes must be <= interval_end_minutes')
return values

@field_validator("interval_start_minutes")
def interval_start_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
if v % info.data["time_resolution_minutes"] != 0:
raise ValueError("interval_start_minutes must be divisible by time_resolution_minutes")
return v

@field_validator("interval_end_minutes")
def interval_end_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
if v % info.data["time_resolution_minutes"] != 0:
raise ValueError("interval_end_minutes must be divisible by time_resolution_minutes")
return v



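To make the new fields concrete, here is a minimal sketch of the validators above in action, assuming TimeWindowMixin can be instantiated directly under pydantic v2; nothing beyond the fields shown in this diff is assumed.

from pydantic import ValidationError

# Valid: both endpoints are multiples of the 30-minute resolution, and start <= end.
TimeWindowMixin(
    time_resolution_minutes=30, interval_start_minutes=-60, interval_end_minutes=120
)

try:
    # Invalid: -45 is not a multiple of the 30-minute resolution.
    TimeWindowMixin(
        time_resolution_minutes=30, interval_start_minutes=-45, interval_end_minutes=120
    )
except ValidationError as err:
    print(err)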
# noinspection PyMethodParameters
class DropoutMixin(Base):
"""Mixin class, to add dropout minutes"""
@@ -76,54 +116,18 @@ def dropout_instructions_consistent(self) -> Self:
return self


# noinspection PyMethodParameters
class TimeWindowMixin(Base):
"""Time resolution mix in"""

time_resolution_minutes: int = Field(
...,
gt=0,
description="The temporal resolution of the data in minutes",
)

forecast_minutes: int = Field(
...,
ge=0,
description="how many minutes to forecast in the future",
)
history_minutes: int = Field(
...,
ge=0,
description="how many historic minutes to use",
)

@field_validator("forecast_minutes")
def forecast_minutes_divide_by_time_resolution(cls, v, values) -> int:
if v % values.data["time_resolution_minutes"] != 0:
message = "Forecast duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v

@field_validator("history_minutes")
def history_minutes_divide_by_time_resolution(cls, v, values) -> int:
if v % values.data["time_resolution_minutes"] != 0:
message = "History duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v


class SpatialWindowMixin(Base):
"""Mixin class, to add path and image size"""

image_size_pixels_height: int = Field(
...,
ge=0,
description="The number of pixels of the height of the region of interest",
)

image_size_pixels_width: int = Field(
...,
ge=0,
description="The number of pixels of the width of the region of interest",
)

@@ -140,10 +144,6 @@ class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
..., description="the satellite channels that are used"
)

live_delay_minutes: int = Field(
..., description="The expected delay in minutes of the satellite data"
)


# noinspection PyMethodParameters
class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
@@ -169,6 +169,7 @@ class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
" the maximum forecast horizon of the NWP and the requested forecast length.",
)


@field_validator("provider")
def validate_provider(cls, v: str) -> str:
"""Validate 'provider'"""
@@ -227,11 +228,10 @@ class Site(TimeWindowMixin, DropoutMixin):
# TODO validate the csv for metadata



# noinspection PyPep8Naming
class InputData(Base):
"""
Input data model.
"""
"""Input data model"""

satellite: Optional[Satellite] = None
nwp: Optional[MultiNWP] = None
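As a before/after sketch of what this change means for configs, the old history/forecast pair maps onto the new interval pair as follows; the values are illustrative, and any other required fields of a section are omitted.

# Illustrative config fragment for one input section; values are made up.
old_style = {
    "time_resolution_minutes": 5,
    "history_minutes": 60,    # removed by this PR
    "forecast_minutes": 120,  # removed by this PR
}
new_style = {
    "time_resolution_minutes": 5,
    "interval_start_minutes": -60,  # data starts at t0 - 60 min
    "interval_end_minutes": 120,    # data ends at t0 + 120 min
}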
115 changes: 40 additions & 75 deletions ocf_data_sampler/select/find_contiguous_time_periods.py
@@ -63,25 +63,25 @@ def find_contiguous_time_periods(

def trim_contiguous_time_periods(
contiguous_time_periods: pd.DataFrame,
history_duration: pd.Timedelta,
forecast_duration: pd.Timedelta,
interval_start: pd.Timedelta,
interval_end: pd.Timedelta,
) -> pd.DataFrame:
"""Trim the contiguous time periods to allow for history and forecast durations.

Args:
contiguous_time_periods: DataFrame where each row represents a single time period. The
DataFrame must have `start_dt` and `end_dt` columns.
history_duration: Length of the historical slice used for a sample
forecast_duration: Length of the forecast slice used for a sample
interval_start: The start of the interval with respect to t0
interval_end: The end of the interval with respect to t0


Returns:
The contiguous_time_periods DataFrame with the `start_dt` and `end_dt` columns updated.
"""
contiguous_time_periods = contiguous_time_periods.copy()

contiguous_time_periods["start_dt"] += history_duration
contiguous_time_periods["end_dt"] -= forecast_duration
contiguous_time_periods["start_dt"] -= interval_start
contiguous_time_periods["end_dt"] -= interval_end

valid_mask = contiguous_time_periods["start_dt"] <= contiguous_time_periods["end_dt"]
contiguous_time_periods = contiguous_time_periods.loc[valid_mask]
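The sign convention deserves a note: interval_start is typically negative (history before t0), so subtracting it pushes start_dt later, while subtracting a positive interval_end pulls end_dt earlier. A small worked sketch with invented timestamps:

import pandas as pd

periods = pd.DataFrame(
    {
        "start_dt": [pd.Timestamp("2024-01-01 00:00")],
        "end_dt": [pd.Timestamp("2024-01-01 06:00")],
    }
)
interval_start = pd.Timedelta("-60min")  # sample uses 60 min of history
interval_end = pd.Timedelta("120min")    # sample uses 120 min of forecast

periods["start_dt"] -= interval_start  # 00:00 -> 01:00 (room for history)
periods["end_dt"] -= interval_end      # 06:00 -> 04:00 (room for forecast)
# Valid t0 values now lie in [01:00, 04:00].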
@@ -92,24 +92,24 @@ def find_contiguous_t0_periods(

def find_contiguous_t0_periods(
datetimes: pd.DatetimeIndex,
history_duration: pd.Timedelta,
forecast_duration: pd.Timedelta,
interval_start: pd.Timedelta,
interval_end: pd.Timedelta,
sample_period_duration: pd.Timedelta,
) -> pd.DataFrame:
"""Return a pd.DataFrame where each row records the boundary of a contiguous time period.

Args:
datetimes: pd.DatetimeIndex. Must be sorted.
history_duration: Length of the historical slice used for each sample
forecast_duration: Length of the forecast slice used for each sample
interval_start: The start of the interval with respect to t0
interval_end: The end of the interval with respect to t0
sample_period_duration: The sample frequency of the timeseries


Returns:
pd.DataFrame where each row represents a single time period. The pd.DataFrame
has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
"""
total_duration = history_duration + forecast_duration
total_duration = interval_end - interval_start

contiguous_time_periods = find_contiguous_time_periods(
datetimes=datetimes,
@@ -119,101 +119,66 @@

contiguous_t0_periods = trim_contiguous_time_periods(
contiguous_time_periods=contiguous_time_periods,
history_duration=history_duration,
forecast_duration=forecast_duration,
interval_start=interval_start,
interval_end=interval_end,
)

assert len(contiguous_t0_periods) > 0

return contiguous_t0_periods

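For example, a call over a regular 5-minute index might look like this; the datetimes and durations are invented for illustration.

import pandas as pd

datetimes = pd.date_range("2024-01-01 00:00", "2024-01-01 12:00", freq="5min")

t0_periods = find_contiguous_t0_periods(
    datetimes=datetimes,
    interval_start=pd.Timedelta("-60min"),  # 60 min of history before t0
    interval_end=pd.Timedelta("120min"),    # 120 min of forecast after t0
    sample_period_duration=pd.Timedelta("5min"),
)
# Expect a single period with start_dt = 01:00 and end_dt = 10:00.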

def _find_contiguous_t0_periods_nwp(
ds,
history_duration: pd.Timedelta,
forecast_duration: pd.Timedelta,
max_staleness: pd.Timedelta | None = None,
max_dropout: pd.Timedelta = pd.Timedelta(0),
time_dim: str = "init_time_utc",
end_buffer: pd.Timedelta = pd.Timedelta(0),
):

assert "step" in ds.coords
# It is possible to use up to this amount of max staleness for the dataset and slice
# required
possible_max_staleness = (
pd.Timedelta(ds["step"].max().item())
- forecast_duration
- end_buffer
)

# If max_staleness is set to None we set it based on the max step ahead of the input
# forecast data
if max_staleness is None:
max_staleness = possible_max_staleness
else:
# Make sure the max acceptable staleness isn't longer than the max possible
assert max_staleness <= possible_max_staleness
max_staleness = max_staleness

contiguous_time_periods = find_contiguous_t0_periods_nwp(
datetimes=pd.DatetimeIndex(ds[time_dim]),
history_duration=history_duration,
max_staleness=max_staleness,
max_dropout=max_dropout,
)
return contiguous_time_periods



def find_contiguous_t0_periods_nwp(
datetimes: pd.DatetimeIndex,
history_duration: pd.Timedelta,
init_times: pd.DatetimeIndex,
interval_start: pd.Timedelta,
max_staleness: pd.Timedelta,
max_dropout: pd.Timedelta = pd.Timedelta(0),
first_forecast_step: pd.Timedelta = pd.Timedelta(0),
) -> pd.DataFrame:
"""Get all time periods from the NWP init times which are valid as t0 datetimes.

Args:
datetimes: Sorted pd.DatetimeIndex
history_duration: Length of the historical slice used for a sample
max_staleness: Up to how long after an NWP forecast init_time are we willing to use the
forecast. Each init time will only be used up to this t0 time regardless of the forecast
valid time.
init_times: The initialisation times of the available forecasts
interval_start: The start of the desired data interval with respect to t0
max_staleness: Up to how long after an init time are we willing to use the forecast. Each
init time will only be used up to this t0 time regardless of the forecast valid time.
max_dropout: What is the maximum amount of dropout that will be used. This must be <=
max_staleness.
first_forecast_step: The timedelta of the first step of the forecast. By default we assume
the first valid time of the forecast is the same as its init time.

Returns:
pd.DataFrame where each row represents a single time period. The pd.DataFrame
has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
"""
# Sanity checks.
assert len(datetimes) > 0
assert datetimes.is_monotonic_increasing
assert datetimes.is_unique
assert history_duration >= pd.Timedelta(0)
assert len(init_times) > 0
assert init_times.is_monotonic_increasing
assert init_times.is_unique
assert max_staleness >= pd.Timedelta(0)
assert max_dropout <= max_staleness
assert pd.Timedelta(0) <= max_dropout <= max_staleness

hist_drop_buffer = max(history_duration, max_dropout)
hist_drop_buffer = max(first_forecast_step - interval_start, max_dropout)

# Store contiguous periods
contiguous_periods = []

# Start first period allowing for history slice and max dropout
start_this_period = datetimes[0] + hist_drop_buffer
# Begin the first period allowing for the time to the first_forecast_step, the length of the
# interval sampled from before t0, and the dropout
start_this_period = init_times[0] + hist_drop_buffer

# The first forecast is valid up to the max staleness
end_this_period = datetimes[0] + max_staleness

for dt_init in datetimes[1:]:
# If the previous init time becomes stale before the next init becomes valid whilst also
# considering dropout - then the contiguous period breaks, and new starts with considering
# dropout and history duration
if end_this_period < dt_init + max_dropout:
end_this_period = init_times[0] + max_staleness

for dt_init in init_times[1:]:
# If the previous init time becomes stale before the next init becomes valid (whilst also
# considering dropout) then the contiguous period breaks
# Else if the previous init time becomes stale before the first step of the next forecast
# then this also causes a break in the contiguous period
if (end_this_period < dt_init + max(max_dropout, first_forecast_step)):
contiguous_periods.append([start_this_period, end_this_period])

# And start a new period
# The new period begins with the same conditions as the first period
start_this_period = dt_init + hist_drop_buffer
end_this_period = dt_init + max_staleness

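To see how first_forecast_step interacts with max_staleness, consider two init times 6 hours apart whose first forecast step is 1 hour after init. The diff is truncated here, so this sketch assumes the loop also refreshes end_this_period for every init time and appends the final period after the loop; all values are invented.

import pandas as pd

init_times = pd.DatetimeIndex(["2024-01-01 00:00", "2024-01-01 06:00"])

nwp_t0_periods = find_contiguous_t0_periods_nwp(
    init_times=init_times,
    interval_start=pd.Timedelta("-60min"),      # 60 min of history before t0
    max_staleness=pd.Timedelta("12h"),          # use each forecast up to 12 h after init
    max_dropout=pd.Timedelta("0min"),
    first_forecast_step=pd.Timedelta("60min"),  # first valid time is init + 1 h
)
# hist_drop_buffer = max(60min - (-60min), 0) = 2 h, so the single contiguous
# period should run from 00:00 + 2 h = 02:00 to 06:00 + 12 h = 18:00.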