From c68d5842c1fe35c2246f3aec2eafe9bea1a8adab Mon Sep 17 00:00:00 2001 From: James Fulton Date: Mon, 7 Oct 2024 17:26:30 +0000 Subject: [PATCH 1/3] switch to interval_[start/end] + allow forecast to have non-zero first step time --- ocf_data_sampler/config/model.py | 105 ++++++---------- .../select/find_contiguous_time_periods.py | 115 ++++++------------ ocf_data_sampler/select/select_time_slice.py | 56 ++++----- .../torch_datasets/pvnet_uk_regional.py | 40 +++--- tests/config/test_config.py | 30 +++-- .../test_find_contiguous_time_periods.py | 16 +-- tests/select/test_select_time_slice.py | 72 +++++------ .../test_data/configs/pvnet_test_config.yaml | 13 +- tests/test_data/configs/test_config.yaml | 13 +- 9 files changed, 195 insertions(+), 265 deletions(-) diff --git a/ocf_data_sampler/config/model.py b/ocf_data_sampler/config/model.py index 34ca90cf..ece94cf8 100644 --- a/ocf_data_sampler/config/model.py +++ b/ocf_data_sampler/config/model.py @@ -40,21 +40,44 @@ class General(Base): ) -class DataSourceMixin(Base): - """Mixin class, to add forecast and history minutes""" +class TimeWindowMixin(Base): + """Mixin class, to add interval start, end and resolution minutes""" - forecast_minutes: int = Field( + time_resolution_minutes: int = Field( ..., - ge=0, - description="how many minutes to forecast in the future. ", + gt=0, + description="The temporal resolution of the data in minutes", ) - history_minutes: int = Field( + + interval_end_minutes: int = Field( ..., - ge=0, - description="how many historic minutes to use. 
", + description="Data interval ends at `t0 + interval_end_minutes`", + ) + interval_start_minutes: int = Field( + ..., + description="Data interval starts at `t0 + interval_start_minutes`", ) + @model_validator(mode='after') + def check_interval_range(cls, values): + if values.interval_start_minutes > values.interval_end_minutes: + raise ValueError('interval_start_minutes must be <= interval_end_minutes') + return values + + @field_validator("interval_start_minutes") + def interval_start_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int: + if v % info.data["time_resolution_minutes"] != 0: + raise ValueError("interval_start_minutes must be divisible by time_resolution_minutes") + return v + + @field_validator("interval_end_minutes") + def interval_end_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int: + if v % info.data["time_resolution_minutes"] != 0: + raise ValueError("interval_end_minutes must be divisible by time_resolution_minutes") + return v + + # noinspection PyMethodParameters class DropoutMixin(Base): """Mixin class, to add dropout minutes""" @@ -65,7 +88,12 @@ class DropoutMixin(Base): "negative or zero.", ) - dropout_fraction: float = Field(0, description="Chance of dropout being applied to each sample") + dropout_fraction: float = Field( + default=0, + description="Chance of dropout being applied to each sample", + ge=0, + le=1, + ) @field_validator("dropout_timedeltas_minutes") def dropout_timedeltas_minutes_negative(cls, v: List[int]) -> List[int]: @@ -75,12 +103,6 @@ def dropout_timedeltas_minutes_negative(cls, v: List[int]) -> List[int]: assert m <= 0, "Dropout timedeltas must be negative" return v - @field_validator("dropout_fraction") - def dropout_fraction_valid(cls, v: float) -> float: - """Validate 'dropout_fraction'""" - assert 0 <= v <= 1, "Dropout fraction must be between 0 and 1" - return v - @model_validator(mode="after") def dropout_instructions_consistent(self) -> Self: if 
self.dropout_fraction == 0: @@ -92,17 +114,7 @@ def dropout_instructions_consistent(self) -> Self: return self -# noinspection PyMethodParameters -class TimeResolutionMixin(Base): - """Time resolution mix in""" - - time_resolution_minutes: int = Field( - ..., - description="The temporal resolution of the data in minutes", - ) - - -class Satellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin): +class Satellite(TimeWindowMixin, DropoutMixin): """Satellite configuration model""" # Todo: remove 'satellite' from names @@ -118,20 +130,15 @@ class Satellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin): description="The number of pixels of the height of the region of interest" " for non-HRV satellite channels.", ) - satellite_image_size_pixels_width: int = Field( ..., description="The number of pixels of the width of the region " "of interest for non-HRV satellite channels.", ) - live_delay_minutes: int = Field( - ..., description="The expected delay in minutes of the satellite data" - ) - # noinspection PyMethodParameters -class NWP(DataSourceMixin, TimeResolutionMixin, DropoutMixin): +class NWP(TimeWindowMixin, DropoutMixin): """NWP configuration model""" nwp_zarr_path: str | tuple[str] | list[str] = Field( @@ -154,7 +161,6 @@ class NWP(DataSourceMixin, TimeResolutionMixin, DropoutMixin): " the maximum forecast horizon of the NWP and the requested forecast length.", ) - @field_validator("nwp_provider") def validate_nwp_provider(cls, v: str) -> str: """Validate 'nwp_provider'""" @@ -164,22 +170,6 @@ def validate_nwp_provider(cls, v: str) -> str: raise Exception(message) return v - # Todo: put into time mixin when moving intervals there - @field_validator("forecast_minutes") - def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int: - if v % info.data["time_resolution_minutes"] != 0: - message = "Forecast duration must be divisible by time resolution" - logger.error(message) - raise Exception(message) - return v - - 
@field_validator("history_minutes") - def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int: - if v % info.data["time_resolution_minutes"] != 0: - message = "History duration must be divisible by time resolution" - logger.error(message) - raise Exception(message) - return v class MultiNWP(RootModel): @@ -209,26 +199,11 @@ def items(self): # noinspection PyMethodParameters -class GSP(DataSourceMixin, TimeResolutionMixin, DropoutMixin): +class GSP(TimeWindowMixin, DropoutMixin): """GSP configuration model""" gsp_zarr_path: str = Field(..., description="The path which holds the GSP zarr") - @field_validator("forecast_minutes") - def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int: - if v % info.data["time_resolution_minutes"] != 0: - message = "Forecast duration must be divisible by time resolution" - logger.error(message) - raise Exception(message) - return v - - @field_validator("history_minutes") - def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int: - if v % info.data["time_resolution_minutes"] != 0: - message = "History duration must be divisible by time resolution" - logger.error(message) - raise Exception(message) - return v # noinspection PyPep8Naming diff --git a/ocf_data_sampler/select/find_contiguous_time_periods.py b/ocf_data_sampler/select/find_contiguous_time_periods.py index 90135139..28f5b3c2 100644 --- a/ocf_data_sampler/select/find_contiguous_time_periods.py +++ b/ocf_data_sampler/select/find_contiguous_time_periods.py @@ -63,16 +63,16 @@ def find_contiguous_time_periods( def trim_contiguous_time_periods( contiguous_time_periods: pd.DataFrame, - history_duration: pd.Timedelta, - forecast_duration: pd.Timedelta, + interval_start: pd.Timedelta, + interval_end: pd.Timedelta, ) -> pd.DataFrame: """Trim the contiguous time periods to allow for history and forecast durations. 
Args: contiguous_time_periods: DataFrame where each row represents a single time period. The DataFrame must have `start_dt` and `end_dt` columns. - history_duration: Length of the historical slice used for a sample - forecast_duration: Length of the forecast slice used for a sample + interval_start: The start of the interval with respect to t0 + interval_end: The end of the interval with respect to t0 Returns: @@ -80,8 +80,8 @@ def trim_contiguous_time_periods( """ contiguous_time_periods = contiguous_time_periods.copy() - contiguous_time_periods["start_dt"] += history_duration - contiguous_time_periods["end_dt"] -= forecast_duration + contiguous_time_periods["start_dt"] -= interval_start + contiguous_time_periods["end_dt"] -= interval_end valid_mask = contiguous_time_periods["start_dt"] <= contiguous_time_periods["end_dt"] contiguous_time_periods = contiguous_time_periods.loc[valid_mask] @@ -92,16 +92,16 @@ def trim_contiguous_time_periods( def find_contiguous_t0_periods( datetimes: pd.DatetimeIndex, - history_duration: pd.Timedelta, - forecast_duration: pd.Timedelta, + interval_start: pd.Timedelta, + interval_end: pd.Timedelta, sample_period_duration: pd.Timedelta, ) -> pd.DataFrame: """Return a pd.DataFrame where each row records the boundary of a contiguous time period. Args: datetimes: pd.DatetimeIndex. Must be sorted. - history_duration: Length of the historical slice used for each sample - forecast_duration: Length of the forecast slice used for each sample + interval_start: The start of the interval with respect to t0 + interval_end: The end of the interval with respect to t0 sample_period_duration: The sample frequency of the timeseries @@ -109,7 +109,7 @@ def find_contiguous_t0_periods( pd.DataFrame where each row represents a single time period. The pd.DataFrame has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime'). 
""" - total_duration = history_duration + forecast_duration + total_duration = interval_end - interval_start contiguous_time_periods = find_contiguous_time_periods( datetimes=datetimes, @@ -119,8 +119,8 @@ def find_contiguous_t0_periods( contiguous_t0_periods = trim_contiguous_time_periods( contiguous_time_periods=contiguous_time_periods, - history_duration=history_duration, - forecast_duration=forecast_duration, + interval_start=interval_start, + interval_end=interval_end, ) assert len(contiguous_t0_periods) > 0 @@ -128,92 +128,57 @@ def find_contiguous_t0_periods( return contiguous_t0_periods -def _find_contiguous_t0_periods_nwp( - ds, - history_duration: pd.Timedelta, - forecast_duration: pd.Timedelta, - max_staleness: pd.Timedelta | None = None, - max_dropout: pd.Timedelta = pd.Timedelta(0), - time_dim: str = "init_time_utc", - end_buffer: pd.Timedelta = pd.Timedelta(0), - ): - - assert "step" in ds.coords - # It is possible to use up to this amount of max staleness for the dataset and slice - # required - possible_max_staleness = ( - pd.Timedelta(ds["step"].max().item()) - - forecast_duration - - end_buffer - ) - - # If max_staleness is set to None we set it based on the max step ahead of the input - # forecast data - if max_staleness is None: - max_staleness = possible_max_staleness - else: - # Make sure the max acceptable staleness isn't longer than the max possible - assert max_staleness <= possible_max_staleness - max_staleness = max_staleness - - contiguous_time_periods = find_contiguous_t0_periods_nwp( - datetimes=pd.DatetimeIndex(ds[time_dim]), - history_duration=history_duration, - max_staleness=max_staleness, - max_dropout=max_dropout, - ) - return contiguous_time_periods - - - def find_contiguous_t0_periods_nwp( - datetimes: pd.DatetimeIndex, - history_duration: pd.Timedelta, + init_times: pd.DatetimeIndex, + interval_start: pd.Timedelta, max_staleness: pd.Timedelta, max_dropout: pd.Timedelta = pd.Timedelta(0), + first_forecast_step: pd.Timedelta = 
pd.Timedelta(0), + ) -> pd.DataFrame: """Get all time periods from the NWP init times which are valid as t0 datetimes. Args: - datetimes: Sorted pd.DatetimeIndex - history_duration: Length of the historical slice used for a sample - max_staleness: Up to how long after an NWP forecast init_time are we willing to use the - forecast. Each init time will only be used up to this t0 time regardless of the forecast - valid time. + init_times: The initialisation times of the available forecasts + interval_start: The start of the desired data interval with respect to t0 + max_staleness: Up to how long after an init time are we willing to use the forecast. Each + init time will only be used up to this t0 time regardless of the forecast valid time. max_dropout: What is the maximum amount of dropout that will be used. This must be <= max_staleness. + first_forecast_step: The timedelta of the first step of the forecast. By default we assume + the first valid time of the forecast is the same as its init time. Returns: pd.DataFrame where each row represents a single time period. The pd.DataFrame has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime'). """ # Sanity checks. 
- assert len(datetimes) > 0 - assert datetimes.is_monotonic_increasing - assert datetimes.is_unique - assert history_duration >= pd.Timedelta(0) + assert len(init_times) > 0 + assert init_times.is_monotonic_increasing + assert init_times.is_unique assert max_staleness >= pd.Timedelta(0) - assert max_dropout <= max_staleness + assert pd.Timedelta(0) <= max_dropout <= max_staleness - hist_drop_buffer = max(history_duration, max_dropout) + hist_drop_buffer = max(first_forecast_step-interval_start, max_dropout) # Store contiguous periods contiguous_periods = [] - # Start first period allowing for history slice and max dropout - start_this_period = datetimes[0] + hist_drop_buffer + # Begin the first period allowing for the time to the first_forecast_step, the length of the + # interval sampled from before t0, and the dropout + start_this_period = init_times[0] + hist_drop_buffer # The first forecast is valid up to the max staleness - end_this_period = datetimes[0] + max_staleness - - for dt_init in datetimes[1:]: - # If the previous init time becomes stale before the next init becomes valid whilst also - # considering dropout - then the contiguous period breaks, and new starts with considering - # dropout and history duration - if end_this_period < dt_init + max_dropout: + end_this_period = init_times[0] + max_staleness + + for dt_init in init_times[1:]: + # If the previous init time becomes stale before the next init becomes valid (whilst also + # considering dropout) then the contiguous period breaks + # Else if the previous init time becomes stale before the first step of the next forecast + # then this also causes a break in the contiguous period + if (end_this_period < dt_init + max(max_dropout, first_forecast_step)): contiguous_periods.append([start_this_period, end_this_period]) - - # And start a new period + # The new period begins with the same conditions as the first period start_this_period = dt_init + hist_drop_buffer end_this_period = dt_init + max_staleness 
diff --git a/ocf_data_sampler/select/select_time_slice.py b/ocf_data_sampler/select/select_time_slice.py index 65bbde25..2be09ef5 100644 --- a/ocf_data_sampler/select/select_time_slice.py +++ b/ocf_data_sampler/select/select_time_slice.py @@ -39,23 +39,15 @@ def _sel_fillinterp( def select_time_slice( ds: xr.DataArray, t0: pd.Timestamp, + interval_start: pd.Timedelta, + interval_end: pd.Timedelta, sample_period_duration: pd.Timedelta, - history_duration: pd.Timedelta | None = None, - forecast_duration: pd.Timedelta | None = None, - interval_start: pd.Timedelta | None = None, - interval_end: pd.Timedelta | None = None, fill_selection: bool = False, max_steps_gap: int = 0, ): """Select a time slice from a Dataset or DataArray.""" - used_duration = history_duration is not None and forecast_duration is not None used_intervals = interval_start is not None and interval_end is not None - assert used_duration ^ used_intervals, "Either durations, or intervals must be supplied" assert max_steps_gap >= 0, "max_steps_gap must be >= 0 " - - if used_duration: - interval_start = - history_duration - interval_end = forecast_duration if fill_selection and max_steps_gap == 0: _sel = _sel_fillnan @@ -75,11 +67,11 @@ def select_time_slice( def select_time_slice_nwp( - ds: xr.DataArray, + da: xr.DataArray, t0: pd.Timestamp, + interval_start: pd.Timedelta, + interval_end: pd.Timedelta, sample_period_duration: pd.Timedelta, - history_duration: pd.Timedelta, - forecast_duration: pd.Timedelta, dropout_timedeltas: list[pd.Timedelta] | None = None, dropout_frac: float | None = 0, accum_channels: list[str] = [], @@ -92,31 +84,31 @@ def select_time_slice_nwp( ), "dropout timedeltas must be negative" assert len(dropout_timedeltas) >= 1 assert 0 <= dropout_frac <= 1 - _consider_dropout = (dropout_timedeltas is not None) and dropout_frac > 0 + consider_dropout = (dropout_timedeltas is not None) and dropout_frac > 0 # The accumatation and non-accumulation channels accum_channels = np.intersect1d( 
- ds[channel_dim_name].values, accum_channels + da[channel_dim_name].values, accum_channels ) non_accum_channels = np.setdiff1d( - ds[channel_dim_name].values, accum_channels + da[channel_dim_name].values, accum_channels ) - start_dt = (t0 - history_duration).ceil(sample_period_duration) - end_dt = (t0 + forecast_duration).ceil(sample_period_duration) + start_dt = (t0 + interval_start).ceil(sample_period_duration) + end_dt = (t0 + interval_end).ceil(sample_period_duration) target_times = pd.date_range(start_dt, end_dt, freq=sample_period_duration) # Maybe apply NWP dropout - if _consider_dropout and (np.random.uniform() < dropout_frac): + if consider_dropout and (np.random.uniform() < dropout_frac): dt = np.random.choice(dropout_timedeltas) t0_available = t0 + dt else: t0_available = t0 # Forecasts made up to and including t0 - available_init_times = ds.init_time_utc.sel( + available_init_times = da.init_time_utc.sel( init_time_utc=slice(None, t0_available) ) @@ -139,7 +131,7 @@ def select_time_slice_nwp( step_indexer = xr.DataArray(steps, coords=coords) if len(accum_channels) == 0: - xr_sel = ds.sel(step=step_indexer, init_time_utc=init_time_indexer) + da_sel = da.sel(step=step_indexer, init_time_utc=init_time_indexer) else: # First minimise the size of the dataset we are diffing @@ -149,7 +141,7 @@ def select_time_slice_nwp( min_step = min(steps) max_step = max(steps) + sample_period_duration - xr_min = ds.sel( + da_min = da.sel( { "init_time_utc": unique_init_times, "step": slice(min_step, max_step), @@ -157,28 +149,28 @@ def select_time_slice_nwp( ) # Slice out the data which does not need to be diffed - xr_non_accum = xr_min.sel({channel_dim_name: non_accum_channels}) - xr_sel_non_accum = xr_non_accum.sel( + da_non_accum = da_min.sel({channel_dim_name: non_accum_channels}) + da_sel_non_accum = da_non_accum.sel( step=step_indexer, init_time_utc=init_time_indexer ) # Slice out the channels which need to be diffed - xr_accum = xr_min.sel({channel_dim_name: 
accum_channels}) + da_accum = da_min.sel({channel_dim_name: accum_channels}) # Take the diff and slice requested data - xr_accum = xr_accum.diff(dim="step", label="lower") - xr_sel_accum = xr_accum.sel(step=step_indexer, init_time_utc=init_time_indexer) + da_accum = da_accum.diff(dim="step", label="lower") + da_sel_accum = da_accum.sel(step=step_indexer, init_time_utc=init_time_indexer) # Join diffed and non-diffed variables - xr_sel = xr.concat([xr_sel_non_accum, xr_sel_accum], dim=channel_dim_name) + da_sel = xr.concat([da_sel_non_accum, da_sel_accum], dim=channel_dim_name) # Reorder the variable back to the original order - xr_sel = xr_sel.sel({channel_dim_name: ds[channel_dim_name].values}) + da_sel = da_sel.sel({channel_dim_name: da[channel_dim_name].values}) # Rename the diffed channels - xr_sel[channel_dim_name] = [ + da_sel[channel_dim_name] = [ f"diff_{v}" if v in accum_channels else v - for v in xr_sel[channel_dim_name].values + for v in da_sel[channel_dim_name].values ] - return xr_sel \ No newline at end of file + return da_sel \ No newline at end of file diff --git a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py index ecd952ae..9179e5f8 100644 --- a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +++ b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py @@ -138,7 +138,7 @@ def find_valid_t0_times( # This is the max staleness we can use considering the max step of the input data max_possible_staleness = ( pd.Timedelta(da["step"].max().item()) - - minutes(nwp_config.forecast_minutes) + - minutes(nwp_config.interval_end_minutes) - end_buffer ) @@ -149,11 +149,15 @@ def find_valid_t0_times( # Make sure the max acceptable staleness isn't longer than the max possible assert max_staleness <= max_possible_staleness + # Find the first forecast step + first_forecast_step = pd.Timedelta(da["step"].min().item()) + time_periods = find_contiguous_t0_periods_nwp( - 
datetimes=pd.DatetimeIndex(da["init_time_utc"]), - history_duration=minutes(nwp_config.history_minutes), + init_times=pd.DatetimeIndex(da["init_time_utc"]), + interval_start=minutes(nwp_config.interval_start_minutes), max_staleness=max_staleness, max_dropout=max_dropout, + first_forecast_step = first_forecast_step, ) contiguous_time_periods[f'nwp_{nwp_key}'] = time_periods @@ -164,8 +168,8 @@ def find_valid_t0_times( time_periods = find_contiguous_t0_periods( pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]), sample_period_duration=minutes(sat_config.time_resolution_minutes), - history_duration=minutes(sat_config.history_minutes), - forecast_duration=minutes(sat_config.forecast_minutes), + interval_start=minutes(sat_config.interval_start_minutes), + interval_end=minutes(sat_config.interval_end_minutes), ) contiguous_time_periods['sat'] = time_periods @@ -176,8 +180,8 @@ def find_valid_t0_times( time_periods = find_contiguous_t0_periods( pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]), sample_period_duration=minutes(gsp_config.time_resolution_minutes), - history_duration=minutes(gsp_config.history_minutes), - forecast_duration=minutes(gsp_config.forecast_minutes), + interval_start=minutes(gsp_config.interval_start_minutes), + interval_end=minutes(gsp_config.interval_end_minutes), ) contiguous_time_periods['gsp'] = time_periods @@ -279,8 +283,8 @@ def slice_datasets_by_time( da_nwp, t0, sample_period_duration=minutes(nwp_config.time_resolution_minutes), - history_duration=minutes(nwp_config.history_minutes), - forecast_duration=minutes(nwp_config.forecast_minutes), + interval_start=minutes(nwp_config.interval_start_minutes), + interval_end=minutes(nwp_config.interval_end_minutes), dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes), dropout_frac=nwp_config.dropout_fraction, accum_channels=nwp_config.nwp_accum_channels, @@ -294,8 +298,8 @@ def slice_datasets_by_time( datasets_dict["sat"], t0, 
sample_period_duration=minutes(sat_config.time_resolution_minutes), - interval_start=minutes(-sat_config.history_minutes), - interval_end=minutes(-sat_config.live_delay_minutes), + interval_start=minutes(sat_config.interval_start_minutes), + interval_end=minutes(sat_config.interval_end_minutes), max_steps_gap=2, ) @@ -319,15 +323,15 @@ def slice_datasets_by_time( datasets_dict["gsp"], t0, sample_period_duration=minutes(gsp_config.time_resolution_minutes), - interval_start=minutes(30), - interval_end=minutes(gsp_config.forecast_minutes), + interval_start=minutes(gsp_config.time_resolution_minutes), + interval_end=minutes(gsp_config.interval_end_minutes), ) - + sliced_datasets_dict["gsp"] = select_time_slice( datasets_dict["gsp"], t0, sample_period_duration=minutes(gsp_config.time_resolution_minutes), - interval_start=-minutes(gsp_config.history_minutes), + interval_start=minutes(gsp_config.interval_start_minutes), interval_end=minutes(0), ) @@ -410,14 +414,14 @@ def process_and_combine_datasets( numpy_modalities.append( convert_gsp_to_numpy_batch( da_gsp, - t0_idx=gsp_config.history_minutes / gsp_config.time_resolution_minutes + t0_idx=abs(gsp_config.interval_start_minutes) / gsp_config.time_resolution_minutes ) ) # Make sun coords NumpyBatch datetimes = pd.date_range( - t0-minutes(gsp_config.history_minutes), - t0+minutes(gsp_config.forecast_minutes), + t0+minutes(gsp_config.interval_start_minutes), + t0+minutes(gsp_config.interval_end_minutes), freq=minutes(gsp_config.time_resolution_minutes), ) diff --git a/tests/config/test_config.py b/tests/config/test_config.py index df5d4182..9661f5ed 100644 --- a/tests/config/test_config.py +++ b/tests/config/test_config.py @@ -10,13 +10,13 @@ ) -def test_default(): +def test_default_configuration(): """Test default pydantic class""" _ = Configuration() -def test_yaml_load_test_config(test_config_filename): +def test_load_yaml_configuration(test_config_filename): """ Test that yaml loading works for 'test_config.yaml' and 
fails for an empty .yaml file @@ -56,7 +56,7 @@ def test_yaml_save(test_config_filename): assert test_config == tmp_config -def test_extra_field(): +def test_extra_field_error(): """ Check an extra parameters in config causes error """ @@ -68,27 +68,33 @@ def test_extra_field(): _ = Configuration(**configuration_dict) -def test_incorrect_forecast_minutes(test_config_filename): +def test_incorrect_interval_start_minutes(test_config_filename): """ - Check a forecast length not divisible by time resolution causes error + Check a history length not divisible by time resolution causes error """ configuration = load_yaml_configuration(test_config_filename) - configuration.input_data.nwp['ukv'].forecast_minutes = 1111 - with pytest.raises(Exception, match="duration must be divisible by time resolution"): + configuration.input_data.nwp['ukv'].interval_start_minutes = -1111 + with pytest.raises( + ValueError, + match="interval_start_minutes must be divisible by time_resolution_minutes" + ): _ = Configuration(**configuration.model_dump()) -def test_incorrect_history_minutes(test_config_filename): +def test_incorrect_interval_end_minutes(test_config_filename): """ - Check a history length not divisible by time resolution causes error + Check a forecast length not divisible by time resolution causes error """ configuration = load_yaml_configuration(test_config_filename) - configuration.input_data.nwp['ukv'].history_minutes = 1111 - with pytest.raises(Exception, match="duration must be divisible by time resolution"): + configuration.input_data.nwp['ukv'].interval_end_minutes = 1111 + with pytest.raises( + ValueError, + match="interval_end_minutes must be divisible by time_resolution_minutes" + ): _ = Configuration(**configuration.model_dump()) @@ -103,6 +109,7 @@ def test_incorrect_nwp_provider(test_config_filename): with pytest.raises(Exception, match="NWP provider"): _ = Configuration(**configuration.model_dump()) + def test_incorrect_dropout(test_config_filename): """ Check 
a dropout timedelta over 0 causes error and 0 doesn't @@ -119,6 +126,7 @@ def test_incorrect_dropout(test_config_filename): configuration.input_data.nwp['ukv'].dropout_timedeltas_minutes = [0] _ = Configuration(**configuration.model_dump()) + def test_incorrect_dropout_fraction(test_config_filename): """ Check dropout fraction outside of range causes error diff --git a/tests/select/test_find_contiguous_time_periods.py b/tests/select/test_find_contiguous_time_periods.py index 9455b829..50243185 100644 --- a/tests/select/test_find_contiguous_time_periods.py +++ b/tests/select/test_find_contiguous_time_periods.py @@ -11,8 +11,8 @@ def test_find_contiguous_t0_periods(): # Create 5-minutely data timestamps freq = pd.Timedelta(5, "min") - history_duration = pd.Timedelta(60, "min") - forecast_duration = pd.Timedelta(15, "min") + interval_start = pd.Timedelta(-60, "min") + interval_end = pd.Timedelta(15, "min") datetimes = ( pd.date_range("2023-01-01 12:00", "2023-01-01 17:00", freq=freq) @@ -21,8 +21,8 @@ def test_find_contiguous_t0_periods(): periods = find_contiguous_t0_periods( datetimes=datetimes, - history_duration=history_duration, - forecast_duration=forecast_duration, + interval_start=interval_start, + interval_end=interval_end, sample_period_duration=freq, ) @@ -135,7 +135,7 @@ def test_find_contiguous_t0_periods_nwp(): # Create 3-hourly init times with a few time stamps missing freq = pd.Timedelta(3, "h") - datetimes = ( + init_times = ( pd.date_range("2023-01-01 03:00", "2023-01-02 21:00", freq=freq) .delete([1, 4, 5, 6, 7, 9, 10]) ) @@ -146,13 +146,13 @@ def test_find_contiguous_t0_periods_nwp(): max_dropouts_hr = [0, 0, 0, 0, 3] for i in range(len(expected_results)): - history_duration = pd.Timedelta(history_durations_hr[i], "h") + interval_start = pd.Timedelta(-history_durations_hr[i], "h") max_staleness = pd.Timedelta(max_stalenesses_hr[i], "h") max_dropout = pd.Timedelta(max_dropouts_hr[i], "h") time_periods = find_contiguous_t0_periods_nwp( - 
datetimes=datetimes, - history_duration=history_duration, + init_times=init_times, + interval_start=interval_start, max_staleness=max_staleness, max_dropout=max_dropout, ) diff --git a/tests/select/test_select_time_slice.py b/tests/select/test_select_time_slice.py index 5b592fef..a43e758e 100644 --- a/tests/select/test_select_time_slice.py +++ b/tests/select/test_select_time_slice.py @@ -55,31 +55,19 @@ def test_select_time_slice(da_sat_like, t0_str): # Slice parameters t0 = pd.Timestamp(f"2024-01-02 {t0_str}") - forecast_duration = pd.Timedelta("0min") - history_duration = pd.Timedelta("60min") + interval_start = pd.Timedelta(-0, "min") + interval_end = pd.Timedelta(60, "min") freq = pd.Timedelta("5min") # Expect to return these timestamps from the selection - expected_datetimes = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq) + expected_datetimes = pd.date_range(t0 +interval_start, t0 + interval_end, freq=freq) - # Make the selection using the `[x]_duration` parameters + # Make the selection sat_sample = select_time_slice( - ds=da_sat_like, + da_sat_like, t0=t0, - history_duration=history_duration, - forecast_duration=forecast_duration, - sample_period_duration=freq, - ) - - # Check the returned times are as expected - assert (sat_sample.time_utc == expected_datetimes).all() - - # Make the selection using the `interval_[x]` parameters - sat_sample = select_time_slice( - ds=da_sat_like, - t0=t0, - interval_start=-history_duration, - interval_end=forecast_duration, + interval_start=interval_start, + interval_end=interval_end, sample_period_duration=freq, ) @@ -93,8 +81,8 @@ def test_select_time_slice_out_of_bounds(da_sat_like, t0_str): # Slice parameters t0 = pd.Timestamp(f"2024-01-02 {t0_str}") - forecast_duration = pd.Timedelta("30min") - history_duration = pd.Timedelta("60min") + interval_start = pd.Timedelta(-30, "min") + interval_end = pd.Timedelta(60, "min") freq = pd.Timedelta("5min") # The data is available between these times @@ 
-102,14 +90,14 @@ def test_select_time_slice_out_of_bounds(da_sat_like, t0_str): max_time = da_sat_like.time_utc.max() # Expect to return these timestamps from the selection - expected_datetimes = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq) + expected_datetimes = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq) # Make the partially out of bounds selection sat_sample = select_time_slice( - ds=da_sat_like, + da_sat_like, t0=t0, - history_duration=history_duration, - forecast_duration=forecast_duration, + interval_start=interval_start, + interval_end=interval_end, sample_period_duration=freq, fill_selection=True ) @@ -138,8 +126,8 @@ def test_select_time_slice_nwp_basic(da_nwp_like, t0_str): # Slice parameters t0 = pd.Timestamp(f"2024-01-02 {t0_str}") - forecast_duration = pd.Timedelta("6H") - history_duration = pd.Timedelta("3H") + interval_start = pd.Timedelta(-6, "H") + interval_end = pd.Timedelta(3, "H") freq = pd.Timedelta("1H") # Make the selection @@ -147,8 +135,8 @@ def test_select_time_slice_nwp_basic(da_nwp_like, t0_str): da_nwp_like, t0, sample_period_duration=freq, - history_duration=history_duration, - forecast_duration=forecast_duration, + interval_start=interval_start, + interval_end=interval_end, dropout_timedeltas = None, dropout_frac = 0, accum_channels = [], @@ -156,7 +144,7 @@ def test_select_time_slice_nwp_basic(da_nwp_like, t0_str): ) # Check the target-times are as expected - expected_target_times = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq) + expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq) assert (da_slice.target_time_utc==expected_target_times).all() # Check the init-times are as expected @@ -172,8 +160,8 @@ def test_select_time_slice_nwp_with_dropout(da_nwp_like, dropout_hours): """Test the functionality of select_time_slice_nwp with dropout""" t0 = pd.Timestamp("2024-01-02 12:00") - forecast_duration = pd.Timedelta("6H") - 
history_duration = pd.Timedelta("3H") + interval_start = pd.Timedelta(-6, "H") + interval_end = pd.Timedelta(3, "H") freq = pd.Timedelta("1H") dropout_timedelta = pd.Timedelta(f"-{dropout_hours}H") @@ -181,8 +169,8 @@ def test_select_time_slice_nwp_with_dropout(da_nwp_like, dropout_hours): da_nwp_like, t0, sample_period_duration=freq, - history_duration=history_duration, - forecast_duration=forecast_duration, + interval_start=interval_start, + interval_end=interval_end, dropout_timedeltas = [dropout_timedelta], dropout_frac = 1, accum_channels = [], @@ -190,7 +178,7 @@ def test_select_time_slice_nwp_with_dropout(da_nwp_like, dropout_hours): ) # Check the target-times are as expected - expected_target_times = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq) + expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq) assert (da_slice.target_time_utc==expected_target_times).all() # Check the init-times are as expected considering the delay @@ -207,8 +195,8 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str): # Slice parameters t0 = pd.Timestamp(f"2024-01-02 {t0_str}") - forecast_duration = pd.Timedelta("6H") - history_duration = pd.Timedelta("3H") + interval_start = pd.Timedelta(-6, "H") + interval_end = pd.Timedelta(3, "H") freq = pd.Timedelta("1H") dropout_timedelta = pd.Timedelta("-2H") @@ -218,8 +206,8 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str): da_nwp_like, t0, sample_period_duration=freq, - history_duration=history_duration, - forecast_duration=forecast_duration, + interval_start=interval_start, + interval_end=interval_end, dropout_timedeltas=[dropout_timedelta], dropout_frac=1, accum_channels=["dswrf"], @@ -227,7 +215,7 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str): ) # Check the target-times are as expected - expected_target_times = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq) + 
expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq) assert (da_slice.target_time_utc==expected_target_times).all() # Check the init-times are as expected considering the delay @@ -254,7 +242,7 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str): init_time_utc=t0_delayed, channel="dswrf", ).diff(dim="step", label="lower") - .sel(step=slice(t0-t0_delayed - history_duration, t0-t0_delayed + forecast_duration)) + .sel(step=slice(t0-t0_delayed + interval_start, t0-t0_delayed + interval_end)) ) # Check the values are the same @@ -275,7 +263,7 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str): init_time_utc=t0_delayed, channel="t", ) - .sel(step=slice(t0-t0_delayed - history_duration, t0-t0_delayed + forecast_duration)) + .sel(step=slice(t0-t0_delayed + interval_start, t0-t0_delayed + interval_end)) ) # Check the values are the same diff --git a/tests/test_data/configs/pvnet_test_config.yaml b/tests/test_data/configs/pvnet_test_config.yaml index 1625ace4..47294a15 100644 --- a/tests/test_data/configs/pvnet_test_config.yaml +++ b/tests/test_data/configs/pvnet_test_config.yaml @@ -6,8 +6,8 @@ input_data: gsp: gsp_zarr_path: set_in_temp_file - history_minutes: 60 - forecast_minutes: 120 + interval_start_minutes: -60 + interval_end_minutes: 120 time_resolution_minutes: 30 dropout_timedeltas_minutes: null dropout_fraction: 0 @@ -16,8 +16,8 @@ input_data: ukv: nwp_provider: ukv nwp_zarr_path: set_in_temp_file - history_minutes: 60 - forecast_minutes: 120 + interval_start_minutes: -60 + interval_end_minutes: 120 time_resolution_minutes: 60 nwp_channels: - t # 2-metre temperature @@ -29,9 +29,8 @@ input_data: satellite: satellite_zarr_path: set_in_temp_file - history_minutes: 30 - forecast_minutes: 0 - live_delay_minutes: 0 + interval_start_minutes: -30 + interval_end_minutes: 0 time_resolution_minutes: 5 satellite_channels: - IR_016 diff --git a/tests/test_data/configs/test_config.yaml 
b/tests/test_data/configs/test_config.yaml index 024f2a3b..a08f5f35 100644 --- a/tests/test_data/configs/test_config.yaml +++ b/tests/test_data/configs/test_config.yaml @@ -4,8 +4,8 @@ general: input_data: gsp: gsp_zarr_path: tests/data/gsp/test.zarr - history_minutes: 60 - forecast_minutes: 120 + interval_start_minutes: -60 + interval_end_minutes: 120 time_resolution_minutes: 30 dropout_timedeltas_minutes: [-30] dropout_fraction: 0.1 @@ -17,8 +17,8 @@ input_data: nwp_image_size_pixels_width: 2 nwp_zarr_path: tests/data/nwp_data/test.zarr nwp_provider: "ukv" - history_minutes: 60 - forecast_minutes: 120 + interval_start_minutes: -60 + interval_end_minutes: 120 time_resolution_minutes: 60 dropout_timedeltas_minutes: [-180] dropout_fraction: 1.0 @@ -30,6 +30,5 @@ input_data: satellite_image_size_pixels_width: 24 satellite_zarr_path: tests/data/sat_data.zarr time_resolution_minutes: 15 - history_minutes: 60 - forecast_minutes: 0 - live_delay_minutes: 0 \ No newline at end of file + interval_start_minutes: -60 + interval_end_minutes: 0 From c7a33a7aa87e0f233d704ff68a3d7f5fceff6e1b Mon Sep 17 00:00:00 2001 From: James Fulton Date: Mon, 7 Oct 2024 17:34:14 +0000 Subject: [PATCH 2/3] fix test --- tests/config/test_config.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/config/test_config.py b/tests/config/test_config.py index 9661f5ed..366dc1da 100644 --- a/tests/config/test_config.py +++ b/tests/config/test_config.py @@ -135,11 +135,12 @@ def test_incorrect_dropout_fraction(test_config_filename): configuration = load_yaml_configuration(test_config_filename) configuration.input_data.nwp['ukv'].dropout_fraction= 1.1 - with pytest.raises(Exception, match="Dropout fraction must be between 0 and 1"): + + with pytest.raises(ValidationError, match="Input should be less than or equal to 1"): _ = Configuration(**configuration.model_dump()) configuration.input_data.nwp['ukv'].dropout_fraction= -0.1 - with pytest.raises(Exception, match="Dropout 
fraction must be between 0 and 1"): + with pytest.raises(ValidationError, match="Input should be greater than or equal to 0"): _ = Configuration(**configuration.model_dump()) From 4742ffcd4adc07e8a4e177bbb05ab40a700eda98 Mon Sep 17 00:00:00 2001 From: James Fulton Date: Fri, 1 Nov 2024 10:14:05 +0000 Subject: [PATCH 3/3] remove unneeded param check --- ocf_data_sampler/select/select_time_slice.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ocf_data_sampler/select/select_time_slice.py b/ocf_data_sampler/select/select_time_slice.py index 2be09ef5..8d5c31e9 100644 --- a/ocf_data_sampler/select/select_time_slice.py +++ b/ocf_data_sampler/select/select_time_slice.py @@ -46,7 +46,6 @@ def select_time_slice( max_steps_gap: int = 0, ): """Select a time slice from a Dataset or DataArray.""" - used_intervals = interval_start is not None and interval_end is not None assert max_steps_gap >= 0, "max_steps_gap must be >= 0 " if fill_selection and max_steps_gap == 0: