diff --git a/ocf_data_sampler/config/model.py b/ocf_data_sampler/config/model.py
index 0e13514..0ec5749 100644
--- a/ocf_data_sampler/config/model.py
+++ b/ocf_data_sampler/config/model.py
@@ -14,7 +14,8 @@ from typing import Dict, List, Optional
 
 from typing_extensions import Self
 
-from pydantic import BaseModel, Field, RootModel, field_validator, model_validator
+from pydantic import BaseModel, Field, RootModel, field_validator, ValidationInfo, model_validator
+
 from ocf_data_sampler.constants import NWP_PROVIDERS
 
 logger = logging.getLogger(__name__)
@@ -40,6 +41,45 @@ class General(Base):
     )
 
 
+class TimeWindowMixin(Base):
+    """Mixin class, to add interval start, end and resolution minutes"""
+
+    time_resolution_minutes: int = Field(
+        ...,
+        gt=0,
+        description="The temporal resolution of the data in minutes",
+    )
+
+    interval_start_minutes: int = Field(
+        ...,
+        description="Data interval starts at `t0 + interval_start_minutes`",
+    )
+
+    interval_end_minutes: int = Field(
+        ...,
+        description="Data interval ends at `t0 + interval_end_minutes`",
+    )
+
+    @model_validator(mode="after")
+    def check_interval_range(self) -> Self:
+        if self.interval_start_minutes > self.interval_end_minutes:
+            raise ValueError("interval_start_minutes must be <= interval_end_minutes")
+        return self
+
+    @field_validator("interval_start_minutes")
+    def interval_start_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
+        if v % info.data["time_resolution_minutes"] != 0:
+            raise ValueError("interval_start_minutes must be divisible by time_resolution_minutes")
+        return v
+
+    @field_validator("interval_end_minutes")
+    def interval_end_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
+        if v % info.data["time_resolution_minutes"] != 0:
+            raise ValueError("interval_end_minutes must be divisible by time_resolution_minutes")
+        return v
+
+
 # noinspection PyMethodParameters
 class DropoutMixin(Base):
     """Mixin class, to add dropout minutes"""
@@ -76,54 +116,18 @@ def dropout_instructions_consistent(self) -> Self:
         return self
 
 
-# noinspection PyMethodParameters
-class TimeWindowMixin(Base):
-    """Time resolution mix in"""
-
-    time_resolution_minutes: int = Field(
-        ...,
-        gt=0,
-        description="The temporal resolution of the data in minutes",
-    )
-
-    forecast_minutes: int = Field(
-        ...,
-        ge=0,
-        description="how many minutes to forecast in the future",
-    )
-    history_minutes: int = Field(
-        ...,
-        ge=0,
-        description="how many historic minutes to use",
-    )
-
-    @field_validator("forecast_minutes")
-    def forecast_minutes_divide_by_time_resolution(cls, v, values) -> int:
-        if v % values.data["time_resolution_minutes"] != 0:
-            message = "Forecast duration must be divisible by time resolution"
-            logger.error(message)
-            raise Exception(message)
-        return v
-
-    @field_validator("history_minutes")
-    def history_minutes_divide_by_time_resolution(cls, v, values) -> int:
-        if v % values.data["time_resolution_minutes"] != 0:
-            message = "History duration must be divisible by time resolution"
-            logger.error(message)
-            raise Exception(message)
-        return v
-
-
 class SpatialWindowMixin(Base):
     """Mixin class, to add path and image size"""
 
     image_size_pixels_height: int = Field(
         ...,
+        ge=0,
         description="The number of pixels of the height of the region of interest",
     )
 
     image_size_pixels_width: int = Field(
         ...,
+        ge=0,
         description="The number of pixels of the width of the region of interest",
     )
 
@@ -140,10 +144,6 @@ class Satellite(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
         ...,
         description="the satellite channels that are used"
     )
 
-    live_delay_minutes: int = Field(
-        ..., description="The expected delay in minutes of the satellite data"
-    )
-
 
 # noinspection PyMethodParameters
 class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
@@ -169,6 +169,7 @@ class NWP(TimeWindowMixin, DropoutMixin, SpatialWindowMixin):
         " the maximum forecast horizon of the NWP and the requested forecast length.",
     )
 
+
     @field_validator("provider")
     def validate_provider(cls, v: str) -> str:
         """Validate 'provider'"""
@@ -227,11 +228,10 @@ class Site(TimeWindowMixin, DropoutMixin):
 
 # TODO validate the csv for metadata
 
+
 # noinspection PyPep8Naming
 class InputData(Base):
-    """
-    Input data model.
-    """
+    """Input data model"""
 
     satellite: Optional[Satellite] = None
     nwp: Optional[MultiNWP] = None
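A minimal sketch (not part of the patch) of how the new `TimeWindowMixin` fields behave, assuming pydantic v2's `ValidationInfo` semantics — the `TimeWindow` model below is a hypothetical stand-in for the mixin:

```python
from pydantic import BaseModel, Field, ValidationError, ValidationInfo, field_validator


class TimeWindow(BaseModel):
    # Stand-in for TimeWindowMixin. Declaration order matters: the field
    # validator reads time_resolution_minutes from info.data, so it must be
    # declared (and validated) first.
    time_resolution_minutes: int = Field(..., gt=0)
    interval_start_minutes: int
    interval_end_minutes: int

    @field_validator("interval_start_minutes", "interval_end_minutes")
    def divisible_by_resolution(cls, v: int, info: ValidationInfo) -> int:
        if v % info.data["time_resolution_minutes"] != 0:
            raise ValueError(f"{info.field_name} must be divisible by time_resolution_minutes")
        return v


# A window from one hour before t0 to two hours after, at 30-minute resolution
TimeWindow(time_resolution_minutes=30, interval_start_minutes=-60, interval_end_minutes=120)

try:
    # -45 is not a multiple of 30, so this raises
    TimeWindow(time_resolution_minutes=30, interval_start_minutes=-45, interval_end_minutes=120)
except ValidationError as err:
    print(err)
```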
diff --git a/ocf_data_sampler/select/find_contiguous_time_periods.py b/ocf_data_sampler/select/find_contiguous_time_periods.py
index 9013513..28f5b3c 100644
--- a/ocf_data_sampler/select/find_contiguous_time_periods.py
+++ b/ocf_data_sampler/select/find_contiguous_time_periods.py
@@ -63,16 +63,16 @@ def find_contiguous_time_periods(
 
 def trim_contiguous_time_periods(
     contiguous_time_periods: pd.DataFrame,
-    history_duration: pd.Timedelta,
-    forecast_duration: pd.Timedelta,
+    interval_start: pd.Timedelta,
+    interval_end: pd.Timedelta,
 ) -> pd.DataFrame:
-    """Trim the contiguous time periods to allow for history and forecast durations.
+    """Trim the contiguous time periods to allow for the data interval around each t0.
 
     Args:
         contiguous_time_periods: DataFrame where each row represents a single time period. The
             DataFrame must have `start_dt` and `end_dt` columns.
-        history_duration: Length of the historical slice used for a sample
-        forecast_duration: Length of the forecast slice used for a sample
+        interval_start: The start of the interval with respect to t0
+        interval_end: The end of the interval with respect to t0
 
     Returns:
@@ -80,8 +80,8 @@ def trim_contiguous_time_periods(
     """
     contiguous_time_periods = contiguous_time_periods.copy()
 
-    contiguous_time_periods["start_dt"] += history_duration
-    contiguous_time_periods["end_dt"] -= forecast_duration
+    contiguous_time_periods["start_dt"] -= interval_start
+    contiguous_time_periods["end_dt"] -= interval_end
 
     valid_mask = contiguous_time_periods["start_dt"] <= contiguous_time_periods["end_dt"]
     contiguous_time_periods = contiguous_time_periods.loc[valid_mask]
@@ -92,16 +92,16 @@ def trim_contiguous_time_periods(
 
 def find_contiguous_t0_periods(
     datetimes: pd.DatetimeIndex,
-    history_duration: pd.Timedelta,
-    forecast_duration: pd.Timedelta,
+    interval_start: pd.Timedelta,
+    interval_end: pd.Timedelta,
     sample_period_duration: pd.Timedelta,
 ) -> pd.DataFrame:
     """Return a pd.DataFrame where each row records the boundary of a contiguous time period.
 
     Args:
         datetimes: pd.DatetimeIndex. Must be sorted.
-        history_duration: Length of the historical slice used for each sample
-        forecast_duration: Length of the forecast slice used for each sample
+        interval_start: The start of the interval with respect to t0
+        interval_end: The end of the interval with respect to t0
         sample_period_duration: The sample frequency of the timeseries
 
 
@@ -109,7 +109,7 @@ def find_contiguous_t0_periods(
         pd.DataFrame where each row represents a single time period. The pd.DataFrame
         has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
     """
-    total_duration = history_duration + forecast_duration
+    total_duration = interval_end - interval_start
 
     contiguous_time_periods = find_contiguous_time_periods(
         datetimes=datetimes,
@@ -119,8 +119,8 @@ def find_contiguous_t0_periods(
 
     contiguous_t0_periods = trim_contiguous_time_periods(
         contiguous_time_periods=contiguous_time_periods,
-        history_duration=history_duration,
-        forecast_duration=forecast_duration,
+        interval_start=interval_start,
+        interval_end=interval_end,
     )
 
     assert len(contiguous_t0_periods) > 0
@@ -128,92 +128,57 @@ def find_contiguous_t0_periods(
     return contiguous_t0_periods
 
 
-def _find_contiguous_t0_periods_nwp(
-        ds,
-        history_duration: pd.Timedelta,
-        forecast_duration: pd.Timedelta,
-        max_staleness: pd.Timedelta | None = None,
-        max_dropout: pd.Timedelta = pd.Timedelta(0),
-        time_dim: str = "init_time_utc",
-        end_buffer: pd.Timedelta = pd.Timedelta(0),
-    ):
-
-    assert "step" in ds.coords
-    # It is possible to use up to this amount of max staleness for the dataset and slice
-    # required
-    possible_max_staleness = (
-        pd.Timedelta(ds["step"].max().item())
-        - forecast_duration
-        - end_buffer
-    )
-
-    # If max_staleness is set to None we set it based on the max step ahead of the input
-    # forecast data
-    if max_staleness is None:
-        max_staleness = possible_max_staleness
-    else:
-        # Make sure the max acceptable staleness isn't longer than the max possible
-        assert max_staleness <= possible_max_staleness
-        max_staleness = max_staleness
-
-    contiguous_time_periods = find_contiguous_t0_periods_nwp(
-        datetimes=pd.DatetimeIndex(ds[time_dim]),
-        history_duration=history_duration,
-        max_staleness=max_staleness,
-        max_dropout=max_dropout,
-    )
-    return contiguous_time_periods
-
-
-
 def find_contiguous_t0_periods_nwp(
-    datetimes: pd.DatetimeIndex,
-    history_duration: pd.Timedelta,
+    init_times: pd.DatetimeIndex,
+    interval_start: pd.Timedelta,
     max_staleness: pd.Timedelta,
     max_dropout: pd.Timedelta = pd.Timedelta(0),
+    first_forecast_step: pd.Timedelta = pd.Timedelta(0),
 ) -> pd.DataFrame:
     """Get all time periods from the NWP init times which are valid as t0 datetimes.
 
     Args:
-        datetimes: Sorted pd.DatetimeIndex
-        history_duration: Length of the historical slice used for a sample
-        max_staleness: Up to how long after an NWP forecast init_time are we willing to use the
-            forecast. Each init time will only be used up to this t0 time regardless of the
-            forecast valid time.
+        init_times: The initialisation times of the available forecasts
+        interval_start: The start of the desired data interval with respect to t0
+        max_staleness: Up to how long after an init time are we willing to use the forecast. Each
+            init time will only be used up to this t0 time regardless of the forecast valid time.
         max_dropout: What is the maximum amount of dropout that will be used. This must be <=
             max_staleness.
+        first_forecast_step: The timedelta of the first step of the forecast. By default we assume
+            the first valid time of the forecast is the same as its init time.
 
     Returns:
         pd.DataFrame where each row represents a single time period. The pd.DataFrame
         has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime').
     """
     # Sanity checks.
-    assert len(datetimes) > 0
-    assert datetimes.is_monotonic_increasing
-    assert datetimes.is_unique
-    assert history_duration >= pd.Timedelta(0)
+    assert len(init_times) > 0
+    assert init_times.is_monotonic_increasing
+    assert init_times.is_unique
     assert max_staleness >= pd.Timedelta(0)
-    assert max_dropout <= max_staleness
+    assert pd.Timedelta(0) <= max_dropout <= max_staleness
 
-    hist_drop_buffer = max(history_duration, max_dropout)
+    hist_drop_buffer = max(first_forecast_step - interval_start, max_dropout)
 
     # Store contiguous periods
    contiguous_periods = []
 
-    # Start first period allowing for history slice and max dropout
-    start_this_period = datetimes[0] + hist_drop_buffer
+    # Begin the first period allowing for the time to the first forecast step, the length of
+    # the interval sampled before t0, and the dropout
+    start_this_period = init_times[0] + hist_drop_buffer
 
     # The first forecast is valid up to the max staleness
-    end_this_period = datetimes[0] + max_staleness
-
-    for dt_init in datetimes[1:]:
-        # If the previous init time becomes stale before the next init becomes valid whilst also
-        # considering dropout - then the contiguous period breaks, and new starts with considering
-        # dropout and history duration
-        if end_this_period < dt_init + max_dropout:
+    end_this_period = init_times[0] + max_staleness
+
+    for dt_init in init_times[1:]:
+        # The contiguous period breaks if the previous init time becomes stale before the next
+        # init becomes valid (whilst also considering dropout), or before the first step of the
+        # next forecast
+        if end_this_period < dt_init + max(max_dropout, first_forecast_step):
             contiguous_periods.append([start_this_period, end_this_period])
-
-            # And start a new period
+            # The new period begins with the same conditions as the first period
             start_this_period = dt_init + hist_drop_buffer
         end_this_period = dt_init + max_staleness
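A usage sketch of the reworked NWP helper, with illustrative values: each init time can serve t0s from `init + max(first_forecast_step - interval_start, max_dropout)` up to `init + max_staleness`, and it relies on the function's unchanged tail appending the final open period.

```python
import pandas as pd

from ocf_data_sampler.select.find_contiguous_time_periods import find_contiguous_t0_periods_nwp

# Three init times, with a 9-hour gap before the last one
init_times = pd.DatetimeIndex(["2023-01-01 00:00", "2023-01-01 03:00", "2023-01-01 12:00"])

periods = find_contiguous_t0_periods_nwp(
    init_times=init_times,
    interval_start=pd.Timedelta(-1, "h"),      # one hour of history is sampled before t0
    max_staleness=pd.Timedelta(6, "h"),        # each forecast may be used up to 6 h after init
    max_dropout=pd.Timedelta(0, "h"),
    first_forecast_step=pd.Timedelta(1, "h"),  # first valid time is init + 1 h
)

# hist_drop_buffer = max(1h - (-1h), 0h) = 2h, so each period starts 2 h after
# an init time, and the 09:00 -> 13:00 staleness gap splits the periods:
#              start_dt              end_dt
# 0 2023-01-01 02:00:00 2023-01-01 09:00:00
# 1 2023-01-01 14:00:00 2023-01-01 18:00:00
print(periods)
```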
""" - total_duration = history_duration + forecast_duration + total_duration = interval_end - interval_start contiguous_time_periods = find_contiguous_time_periods( datetimes=datetimes, @@ -119,8 +119,8 @@ def find_contiguous_t0_periods( contiguous_t0_periods = trim_contiguous_time_periods( contiguous_time_periods=contiguous_time_periods, - history_duration=history_duration, - forecast_duration=forecast_duration, + interval_start=interval_start, + interval_end=interval_end, ) assert len(contiguous_t0_periods) > 0 @@ -128,92 +128,57 @@ def find_contiguous_t0_periods( return contiguous_t0_periods -def _find_contiguous_t0_periods_nwp( - ds, - history_duration: pd.Timedelta, - forecast_duration: pd.Timedelta, - max_staleness: pd.Timedelta | None = None, - max_dropout: pd.Timedelta = pd.Timedelta(0), - time_dim: str = "init_time_utc", - end_buffer: pd.Timedelta = pd.Timedelta(0), - ): - - assert "step" in ds.coords - # It is possible to use up to this amount of max staleness for the dataset and slice - # required - possible_max_staleness = ( - pd.Timedelta(ds["step"].max().item()) - - forecast_duration - - end_buffer - ) - - # If max_staleness is set to None we set it based on the max step ahead of the input - # forecast data - if max_staleness is None: - max_staleness = possible_max_staleness - else: - # Make sure the max acceptable staleness isn't longer than the max possible - assert max_staleness <= possible_max_staleness - max_staleness = max_staleness - - contiguous_time_periods = find_contiguous_t0_periods_nwp( - datetimes=pd.DatetimeIndex(ds[time_dim]), - history_duration=history_duration, - max_staleness=max_staleness, - max_dropout=max_dropout, - ) - return contiguous_time_periods - - - def find_contiguous_t0_periods_nwp( - datetimes: pd.DatetimeIndex, - history_duration: pd.Timedelta, + init_times: pd.DatetimeIndex, + interval_start: pd.Timedelta, max_staleness: pd.Timedelta, max_dropout: pd.Timedelta = pd.Timedelta(0), + first_forecast_step: pd.Timedelta = pd.Timedelta(0), + ) -> pd.DataFrame: """Get all time periods from the NWP init times which are valid as t0 datetimes. Args: - datetimes: Sorted pd.DatetimeIndex - history_duration: Length of the historical slice used for a sample - max_staleness: Up to how long after an NWP forecast init_time are we willing to use the - forecast. Each init time will only be used up to this t0 time regardless of the forecast - valid time. + init_times: The initialisation times of the available forecasts + interval_start: The start of the desired data interval with respect to t0 + max_staleness: Up to how long after an init time are we willing to use the forecast. Each + init time will only be used up to this t0 time regardless of the forecast valid time. max_dropout: What is the maximum amount of dropout that will be used. This must be <= max_staleness. + first_forecast_step: The timedelta of the first step of the forecast. By default we assume + the first valid time of the forecast is the same as its init time. Returns: pd.DataFrame where each row represents a single time period. The pd.DataFrame has two columns: `start_dt` and `end_dt` (where 'dt' is short for 'datetime'). """ # Sanity checks. 
diff --git a/ocf_data_sampler/select/time_slice_for_dataset.py b/ocf_data_sampler/select/time_slice_for_dataset.py
index b7905fa..1247ae6 100644
--- a/ocf_data_sampler/select/time_slice_for_dataset.py
+++ b/ocf_data_sampler/select/time_slice_for_dataset.py
@@ -4,7 +4,7 @@ from ocf_data_sampler.config import Configuration
 from ocf_data_sampler.select.dropout import draw_dropout_time, apply_dropout_time
 from ocf_data_sampler.select.select_time_slice import select_time_slice_nwp, select_time_slice
-from ocf_data_sampler.time_functions import minutes
+from ocf_data_sampler.utils import minutes
 
 
 def slice_datasets_by_time(
@@ -23,19 +23,19 @@ def slice_datasets_by_time(
     sliced_datasets_dict = {}
 
     if "nwp" in datasets_dict:
-        
+
         sliced_datasets_dict["nwp"] = {}
-        
+
         for nwp_key, da_nwp in datasets_dict["nwp"].items():
-            
+
             nwp_config = config.input_data.nwp[nwp_key]
 
             sliced_datasets_dict["nwp"][nwp_key] = select_time_slice_nwp(
                 da_nwp,
                 t0,
                 sample_period_duration=minutes(nwp_config.time_resolution_minutes),
-                history_duration=minutes(nwp_config.history_minutes),
-                forecast_duration=minutes(nwp_config.forecast_minutes),
+                interval_start=minutes(nwp_config.interval_start_minutes),
+                interval_end=minutes(nwp_config.interval_end_minutes),
                 dropout_timedeltas=minutes(nwp_config.dropout_timedeltas_minutes),
                 dropout_frac=nwp_config.dropout_fraction,
                 accum_channels=nwp_config.accum_channels,
@@ -49,8 +49,8 @@ def slice_datasets_by_time(
             datasets_dict["sat"],
             t0,
             sample_period_duration=minutes(sat_config.time_resolution_minutes),
-            interval_start=minutes(-sat_config.history_minutes),
-            interval_end=minutes(-sat_config.live_delay_minutes),
+            interval_start=minutes(sat_config.interval_start_minutes),
+            interval_end=minutes(sat_config.interval_end_minutes),
             max_steps_gap=2,
         )
 
@@ -74,15 +74,15 @@ def slice_datasets_by_time(
             datasets_dict["gsp"],
             t0,
             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-            interval_start=minutes(30),
-            interval_end=minutes(gsp_config.forecast_minutes),
+            interval_start=minutes(gsp_config.time_resolution_minutes),
+            interval_end=minutes(gsp_config.interval_end_minutes),
         )
-        
+
         sliced_datasets_dict["gsp"] = select_time_slice(
             datasets_dict["gsp"],
             t0,
             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-            interval_start=-minutes(gsp_config.history_minutes),
+            interval_start=minutes(gsp_config.interval_start_minutes),
             interval_end=minutes(0),
         )
 
@@ -94,9 +94,10 @@ def slice_datasets_by_time(
         )
 
         sliced_datasets_dict["gsp"] = apply_dropout_time(
-            sliced_datasets_dict["gsp"], gsp_dropout_time
+            sliced_datasets_dict["gsp"],
+            gsp_dropout_time
         )
-    
+
     if "site" in datasets_dict:
         site_config = config.input_data.site
 
@@ -104,8 +105,8 @@ def slice_datasets_by_time(
             datasets_dict["site"],
             t0,
             sample_period_duration=minutes(site_config.time_resolution_minutes),
-            interval_start=-minutes(site_config.history_minutes),
-            interval_end=minutes(site_config.forecast_minutes),
+            interval_start=minutes(site_config.interval_start_minutes),
+            interval_end=minutes(site_config.interval_end_minutes),
         )
 
     # Randomly sample dropout
diff --git a/ocf_data_sampler/torch_datasets/Readme.md b/ocf_data_sampler/torch_datasets/README.md
similarity index 96%
rename from ocf_data_sampler/torch_datasets/Readme.md
rename to ocf_data_sampler/torch_datasets/README.md
index 606de43..8abc2bb 100644
--- a/ocf_data_sampler/torch_datasets/Readme.md
+++ b/ocf_data_sampler/torch_datasets/README.md
@@ -4,7 +4,7 @@ The aim of this folder is to create torch datasets which can easily be used in o
 
 ## PVNet UK Regional
 
-This pipline is for creating GSP predictions which we have used in our PVNet model.
+This dataset is for creating the GSP prediction samples used by our PVNet model.
 
 ### Init
diff --git a/ocf_data_sampler/torch_datasets/process_and_combine.py b/ocf_data_sampler/torch_datasets/process_and_combine.py
index 334e0a6..a732ef0 100644
--- a/ocf_data_sampler/torch_datasets/process_and_combine.py
+++ b/ocf_data_sampler/torch_datasets/process_and_combine.py
@@ -15,7 +15,7 @@ from ocf_data_sampler.numpy_batch.nwp import NWPBatchKey
 from ocf_data_sampler.select.geospatial import osgb_to_lon_lat
 from ocf_data_sampler.select.location import Location
-from ocf_data_sampler.time_functions import minutes
+from ocf_data_sampler.utils import minutes
 
 
 def process_and_combine_datasets(
@@ -23,7 +23,7 @@ def process_and_combine_datasets(
     config: Configuration,
     t0: pd.Timestamp,
     location: Location,
-    sun_position_key: str = 'gsp'
+    target_key: str = 'gsp'
 ) -> dict:
     """Normalize and convert data to numpy arrays"""
@@ -58,7 +58,8 @@ def process_and_combine_datasets(
 
         numpy_modalities.append(
             convert_gsp_to_numpy_batch(
-                da_gsp, t0_idx=gsp_config.history_minutes // gsp_config.time_resolution_minutes
+                da_gsp,
+                t0_idx=-gsp_config.interval_start_minutes // gsp_config.time_resolution_minutes
             )
         )
 
@@ -80,34 +81,32 @@ def process_and_combine_datasets(
 
         numpy_modalities.append(
             convert_site_to_numpy_batch(
-                da_sites, t0_idx=site_config.history_minutes / site_config.time_resolution_minutes
+                da_sites,
+                t0_idx=-site_config.interval_start_minutes // site_config.time_resolution_minutes
             )
         )
 
-    if sun_position_key == 'gsp':
+    if target_key == 'gsp':
         # Make sun coords NumpyBatch
         datetimes = pd.date_range(
-            t0 - minutes(gsp_config.history_minutes),
-            t0 + minutes(gsp_config.forecast_minutes),
+            t0 + minutes(gsp_config.interval_start_minutes),
+            t0 + minutes(gsp_config.interval_end_minutes),
             freq=minutes(gsp_config.time_resolution_minutes),
         )
 
         lon, lat = osgb_to_lon_lat(location.x, location.y)
 
-        key_prefix = "gsp"
-
-    elif sun_position_key == 'site':
+    elif target_key == 'site':
         # Make sun coords NumpyBatch
         datetimes = pd.date_range(
-            t0 - minutes(site_config.history_minutes),
-            t0 + minutes(site_config.forecast_minutes),
+            t0 + minutes(site_config.interval_start_minutes),
+            t0 + minutes(site_config.interval_end_minutes),
             freq=minutes(site_config.time_resolution_minutes),
         )
 
         lon, lat = location.x, location.y
-        key_prefix = "site"
 
     numpy_modalities.append(
-        make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix=key_prefix)
+        make_sun_position_numpy_batch(datetimes, lon, lat, key_prefix=target_key)
     )
 
     # Combine all the modalities and fill NaNs
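The `t0_idx` arithmetic above follows from `interval_start_minutes` being a (usually negative) offset rather than a positive history length; using `//` keeps the index an integer. For example:

```python
# A GSP window from -60 min to +120 min at 30-minute resolution holds
# 7 timestamps, and t0 sits at index 2
interval_start_minutes = -60
time_resolution_minutes = 30

t0_idx = -interval_start_minutes // time_resolution_minutes
print(t0_idx)  # 2
```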
diff --git a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
index e9c1c42..1f6d435 100644
--- a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
+++ b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
@@ -9,7 +9,7 @@ from ocf_data_sampler.config import Configuration, load_yaml_configuration
 from ocf_data_sampler.load.load_dataset import get_dataset_dict
 from ocf_data_sampler.select import fill_time_periods, Location, slice_datasets_by_space, slice_datasets_by_time
-from ocf_data_sampler.time_functions import minutes
+from ocf_data_sampler.utils import minutes
 from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
 from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods
 
diff --git a/ocf_data_sampler/torch_datasets/site.py b/ocf_data_sampler/torch_datasets/site.py
index b92f821..3ec2fc3 100644
--- a/ocf_data_sampler/torch_datasets/site.py
+++ b/ocf_data_sampler/torch_datasets/site.py
@@ -14,7 +14,7 @@
     intersection_of_multiple_dataframes_of_periods,
     slice_datasets_by_time, slice_datasets_by_space
 )
-from ocf_data_sampler.time_functions import minutes
+from ocf_data_sampler.utils import minutes
 from ocf_data_sampler.torch_datasets.process_and_combine import process_and_combine_datasets, compute
 from ocf_data_sampler.torch_datasets.valid_time_periods import find_valid_time_periods
@@ -22,8 +22,8 @@
 
 
 def find_valid_t0_and_site_ids(
-    datasets_dict: dict, 
-    config: Configuration, 
+    datasets_dict: dict,
+    config: Configuration,
 ) -> pd.DataFrame:
     """Find the t0 times where all of the requested input data is available
@@ -57,8 +57,8 @@ def find_valid_t0_and_site_ids(
         time_periods = find_contiguous_t0_periods(
             pd.DatetimeIndex(site["time_utc"]),
             sample_period_duration=minutes(site_config.time_resolution_minutes),
-            history_duration=minutes(site_config.history_minutes),
-            forecast_duration=minutes(site_config.forecast_minutes),
+            interval_start=minutes(site_config.interval_start_minutes),
+            interval_end=minutes(site_config.interval_end_minutes),
         )
         valid_time_periods_per_site = intersection_of_multiple_dataframes_of_periods(
             [valid_time_periods, time_periods]
@@ -100,10 +100,10 @@ def get_locations(site_xr: xr.Dataset):
 
 class SitesDataset(Dataset):
     def __init__(
-        self, 
-        config_filename: str, 
-        start_time: str | None = None, 
-        end_time: str | None = None, 
+        self,
+        config_filename: str,
+        start_time: str | None = None,
+        end_time: str | None = None,
     ):
         """A torch Dataset for creating PVNet Site samples
 
@@ -154,7 +154,7 @@ class SitesDataset(Dataset):
         sample_dict = slice_datasets_by_time(sample_dict, t0, self.config)
         sample_dict = compute(sample_dict)
 
-        sample = process_and_combine_datasets(sample_dict, self.config, t0, location, sun_position_key='site')
+        sample = process_and_combine_datasets(sample_dict, self.config, t0, location, target_key='site')
 
         return sample
diff --git a/ocf_data_sampler/torch_datasets/valid_time_periods.py b/ocf_data_sampler/torch_datasets/valid_time_periods.py
index 6c9b986..2212e6f 100644
--- a/ocf_data_sampler/torch_datasets/valid_time_periods.py
+++ b/ocf_data_sampler/torch_datasets/valid_time_periods.py
@@ -2,9 +2,13 @@
 import pandas as pd
 
 from ocf_data_sampler.config import Configuration
-from ocf_data_sampler.select.find_contiguous_time_periods import find_contiguous_t0_periods_nwp, \
-    find_contiguous_t0_periods, intersection_of_multiple_dataframes_of_periods
-from ocf_data_sampler.time_functions import minutes
+from ocf_data_sampler.select.find_contiguous_time_periods import (
+    find_contiguous_t0_periods_nwp,
+    find_contiguous_t0_periods,
+    intersection_of_multiple_dataframes_of_periods,
+)
+from ocf_data_sampler.utils import minutes
 
 
 def find_valid_time_periods(
@@ -46,7 +50,7 @@ def find_valid_time_periods(
             # This is the max staleness we can use considering the max step of the input data
             max_possible_staleness = (
                 pd.Timedelta(da["step"].max().item())
-                - minutes(nwp_config.forecast_minutes)
+                - minutes(nwp_config.interval_end_minutes)
                 - end_buffer
             )
 
@@ -56,12 +60,16 @@ def find_valid_time_periods(
             else:
                 # Make sure the max acceptable staleness isn't longer than the max possible
                 assert max_staleness <= max_possible_staleness
+
+            # Find the first forecast step
+            first_forecast_step = pd.Timedelta(da["step"].min().item())
 
             time_periods = find_contiguous_t0_periods_nwp(
-                datetimes=pd.DatetimeIndex(da["init_time_utc"]),
-                history_duration=minutes(nwp_config.history_minutes),
+                init_times=pd.DatetimeIndex(da["init_time_utc"]),
+                interval_start=minutes(nwp_config.interval_start_minutes),
                 max_staleness=max_staleness,
                 max_dropout=max_dropout,
+                first_forecast_step=first_forecast_step,
             )
 
             contiguous_time_periods[f'nwp_{nwp_key}'] = time_periods
@@ -72,8 +80,8 @@ def find_valid_time_periods(
         time_periods = find_contiguous_t0_periods(
             pd.DatetimeIndex(datasets_dict["sat"]["time_utc"]),
             sample_period_duration=minutes(sat_config.time_resolution_minutes),
-            history_duration=minutes(sat_config.history_minutes),
-            forecast_duration=minutes(sat_config.forecast_minutes),
+            interval_start=minutes(sat_config.interval_start_minutes),
+            interval_end=minutes(sat_config.interval_end_minutes),
         )
 
         contiguous_time_periods['sat'] = time_periods
@@ -84,8 +92,8 @@ def find_valid_time_periods(
         time_periods = find_contiguous_t0_periods(
             pd.DatetimeIndex(datasets_dict["gsp"]["time_utc"]),
             sample_period_duration=minutes(gsp_config.time_resolution_minutes),
-            history_duration=minutes(gsp_config.history_minutes),
-            forecast_duration=minutes(gsp_config.forecast_minutes),
+            interval_start=minutes(gsp_config.interval_start_minutes),
+            interval_end=minutes(gsp_config.interval_end_minutes),
         )
 
         contiguous_time_periods['gsp'] = time_periods
@@ -105,4 +113,4 @@ def find_valid_time_periods(
     if len(valid_time_periods) == 0:
         raise ValueError(f"No valid time periods found, {contiguous_time_periods=}")
 
-    return valid_time_periods
+    return valid_time_periods
\ No newline at end of file
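A worked example of the staleness bound computed above, for a hypothetical NWP input whose steps run out to 48 hours:

```python
import pandas as pd

max_step = pd.Timedelta(48, "h")         # pd.Timedelta(da["step"].max().item())
interval_end = pd.Timedelta(120, "min")  # minutes(nwp_config.interval_end_minutes)
end_buffer = pd.Timedelta(0, "min")

# A forecast can serve t0s up to this long after its init time and still
# cover the requested window out to t0 + interval_end
max_possible_staleness = max_step - interval_end - end_buffer
print(max_possible_staleness)  # 1 days 22:00:00
```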
diff --git a/ocf_data_sampler/time_functions.py b/ocf_data_sampler/utils.py
similarity index 71%
rename from ocf_data_sampler/time_functions.py
rename to ocf_data_sampler/utils.py
index bbb09fd..ed37cb2 100644
--- a/ocf_data_sampler/time_functions.py
+++ b/ocf_data_sampler/utils.py
@@ -7,5 +7,4 @@ def minutes(minutes: int | list[float]) -> pd.Timedelta | pd.TimedeltaIndex:
     Args:
         minutes: the number of minutes, single value or list
     """
-    minutes_delta = pd.to_timedelta(minutes, unit="m")
-    return minutes_delta
+    return pd.to_timedelta(minutes, unit="m")
diff --git a/tests/config/test_config.py b/tests/config/test_config.py
index 2a606d4..7cc9b09 100644
--- a/tests/config/test_config.py
+++ b/tests/config/test_config.py
@@ -68,27 +68,33 @@ def test_extra_field_error():
         _ = Configuration(**configuration_dict)
 
 
-def test_incorrect_forecast_minutes(test_config_filename):
+def test_incorrect_interval_start_minutes(test_config_filename):
     """
-    Check a forecast length not divisible by time resolution causes error
+    Check a history length not divisible by time resolution causes an error
     """
 
     configuration = load_yaml_configuration(test_config_filename)
 
-    configuration.input_data.nwp['ukv'].forecast_minutes = 1111
-    with pytest.raises(Exception, match="duration must be divisible by time resolution"):
+    configuration.input_data.nwp['ukv'].interval_start_minutes = -1111
+    with pytest.raises(
+        ValueError,
+        match="interval_start_minutes must be divisible by time_resolution_minutes"
+    ):
         _ = Configuration(**configuration.model_dump())
 
 
-def test_incorrect_history_minutes(test_config_filename):
+def test_incorrect_interval_end_minutes(test_config_filename):
     """
-    Check a history length not divisible by time resolution causes error
+    Check a forecast length not divisible by time resolution causes an error
     """
 
     configuration = load_yaml_configuration(test_config_filename)
 
-    configuration.input_data.nwp['ukv'].history_minutes = 1111
-    with pytest.raises(Exception, match="duration must be divisible by time resolution"):
+    configuration.input_data.nwp['ukv'].interval_end_minutes = 1111
+    with pytest.raises(
+        ValueError,
+        match="interval_end_minutes must be divisible by time_resolution_minutes"
+    ):
         _ = Configuration(**configuration.model_dump())
diff --git a/tests/conftest.py b/tests/conftest.py
index 2964d99..06e6348 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -250,11 +250,13 @@ def data_sites() -> Site:
     generation.to_netcdf(filename)
     meta_df.to_csv(filename_csv)
 
-    site = Site(file_path=filename,
-                metadata_file_path=filename_csv,
-                time_resolution_minutes=30,
-                forecast_minutes=60,
-                history_minutes=30)
+    site = Site(
+        file_path=filename,
+        metadata_file_path=filename_csv,
+        interval_start_minutes=-30,
+        interval_end_minutes=60,
+        time_resolution_minutes=30,
+    )
 
     yield site
 
diff --git a/tests/select/test_find_contiguous_time_periods.py b/tests/select/test_find_contiguous_time_periods.py
index 9455b82..5024318 100644
--- a/tests/select/test_find_contiguous_time_periods.py
+++ b/tests/select/test_find_contiguous_time_periods.py
@@ -11,8 +11,8 @@ def test_find_contiguous_t0_periods():
     # Create 5-minutely data timestamps
     freq = pd.Timedelta(5, "min")
 
-    history_duration = pd.Timedelta(60, "min")
-    forecast_duration = pd.Timedelta(15, "min")
+    interval_start = pd.Timedelta(-60, "min")
+    interval_end = pd.Timedelta(15, "min")
 
     datetimes = (
         pd.date_range("2023-01-01 12:00", "2023-01-01 17:00", freq=freq)
@@ -21,8 +21,8 @@ def test_find_contiguous_t0_periods():
 
     periods = find_contiguous_t0_periods(
         datetimes=datetimes,
-        history_duration=history_duration,
-        forecast_duration=forecast_duration,
+        interval_start=interval_start,
+        interval_end=interval_end,
         sample_period_duration=freq,
     )
 
@@ -135,7 +135,7 @@ def test_find_contiguous_t0_periods_nwp():
     # Create 3-hourly init times with a few time stamps missing
     freq = pd.Timedelta(3, "h")
 
-    datetimes = (
+    init_times = (
         pd.date_range("2023-01-01 03:00", "2023-01-02 21:00", freq=freq)
         .delete([1, 4, 5, 6, 7, 9, 10])
     )
@@ -146,13 +146,13 @@ def test_find_contiguous_t0_periods_nwp():
     max_dropouts_hr = [0, 0, 0, 0, 3]
 
     for i in range(len(expected_results)):
-        history_duration = pd.Timedelta(history_durations_hr[i], "h")
+        interval_start = pd.Timedelta(-history_durations_hr[i], "h")
         max_staleness = pd.Timedelta(max_stalenesses_hr[i], "h")
         max_dropout = pd.Timedelta(max_dropouts_hr[i], "h")
 
         time_periods = find_contiguous_t0_periods_nwp(
-            datetimes=datetimes,
-            history_duration=history_duration,
+            init_times=init_times,
+            interval_start=interval_start,
             max_staleness=max_staleness,
             max_dropout=max_dropout,
         )
diff --git a/tests/select/test_select_time_slice.py b/tests/select/test_select_time_slice.py
index dc01f9b..b4c8d9c 100644
--- a/tests/select/test_select_time_slice.py
+++ b/tests/select/test_select_time_slice.py
@@ -55,31 +55,19 @@ def test_select_time_slice(da_sat_like, t0_str):
 
     # Slice parameters
     t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
-    forecast_duration = pd.Timedelta("0min")
-    history_duration = pd.Timedelta("60min")
+    interval_start = pd.Timedelta(-60, "min")
+    interval_end = pd.Timedelta(0, "min")
     freq = pd.Timedelta("5min")
 
     # Expect to return these timestamps from the selection
-    expected_datetimes = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq)
+    expected_datetimes = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
 
-    # Make the selection using the `[x]_duration` parameters
+    # Make the selection
     sat_sample = select_time_slice(
-        ds=da_sat_like,
+        da_sat_like,
         t0=t0,
-        history_duration=history_duration,
-        forecast_duration=forecast_duration,
-        sample_period_duration=freq,
-    )
-
-    # Check the returned times are as expected
-    assert (sat_sample.time_utc == expected_datetimes).all()
-
-    # Make the selection using the `interval_[x]` parameters
-    sat_sample = select_time_slice(
-        ds=da_sat_like,
-        t0=t0,
-        interval_start=-history_duration,
-        interval_end=forecast_duration,
+        interval_start=interval_start,
+        interval_end=interval_end,
         sample_period_duration=freq,
     )
 
@@ -93,8 +81,8 @@ def test_select_time_slice_out_of_bounds(da_sat_like, t0_str):
 
     # Slice parameters
     t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
-    forecast_duration = pd.Timedelta("30min")
-    history_duration = pd.Timedelta("60min")
+    interval_start = pd.Timedelta(-60, "min")
+    interval_end = pd.Timedelta(30, "min")
     freq = pd.Timedelta("5min")
 
     # The data is available between these times
@@ -102,14 +90,14 @@ def test_select_time_slice_out_of_bounds(da_sat_like, t0_str):
     max_time = da_sat_like.time_utc.max()
 
     # Expect to return these timestamps from the selection
-    expected_datetimes = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq)
+    expected_datetimes = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
 
     # Make the partially out of bounds selection
     sat_sample = select_time_slice(
-        ds=da_sat_like,
+        da_sat_like,
         t0=t0,
-        history_duration=history_duration,
-        forecast_duration=forecast_duration,
+        interval_start=interval_start,
+        interval_end=interval_end,
         sample_period_duration=freq,
         fill_selection=True
     )
@@ -138,8 +126,8 @@ def test_select_time_slice_nwp_basic(da_nwp_like, t0_str):
 
     # Slice parameters
     t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
-    forecast_duration = pd.Timedelta("6h")
-    history_duration = pd.Timedelta("3h")
+    interval_start = pd.Timedelta(-3, "h")
+    interval_end = pd.Timedelta(6, "h")
     freq = pd.Timedelta("1h")
 
     # Make the selection
@@ -147,8 +135,8 @@ def test_select_time_slice_nwp_basic(da_nwp_like, t0_str):
         da_nwp_like,
         t0,
         sample_period_duration=freq,
-        history_duration=history_duration,
-        forecast_duration=forecast_duration,
+        interval_start=interval_start,
+        interval_end=interval_end,
         dropout_timedeltas = None,
         dropout_frac = 0,
         accum_channels = [],
@@ -156,7 +144,7 @@ def test_select_time_slice_nwp_basic(da_nwp_like, t0_str):
     )
 
     # Check the target-times are as expected
-    expected_target_times = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq)
+    expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
     assert (da_slice.target_time_utc == expected_target_times).all()
 
     # Check the init-times are as expected
@@ -172,8 +160,8 @@ def test_select_time_slice_nwp_with_dropout(da_nwp_like, dropout_hours):
     """Test the functionality of select_time_slice_nwp with dropout"""
 
     t0 = pd.Timestamp("2024-01-02 12:00")
-    forecast_duration = pd.Timedelta("6h")
-    history_duration = pd.Timedelta("3h")
+    interval_start = pd.Timedelta(-3, "h")
+    interval_end = pd.Timedelta(6, "h")
     freq = pd.Timedelta("1h")
     dropout_timedelta = pd.Timedelta(f"-{dropout_hours}h")
 
@@ -181,8 +169,8 @@ def test_select_time_slice_nwp_with_dropout(da_nwp_like, dropout_hours):
         da_nwp_like,
         t0,
         sample_period_duration=freq,
-        history_duration=history_duration,
-        forecast_duration=forecast_duration,
+        interval_start=interval_start,
+        interval_end=interval_end,
         dropout_timedeltas = [dropout_timedelta],
         dropout_frac = 1,
         accum_channels = [],
@@ -190,7 +178,7 @@ def test_select_time_slice_nwp_with_dropout(da_nwp_like, dropout_hours):
     )
 
     # Check the target-times are as expected
-    expected_target_times = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq)
+    expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
     assert (da_slice.target_time_utc == expected_target_times).all()
 
     # Check the init-times are as expected considering the delay
@@ -207,9 +195,9 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
 
     # Slice parameters
     t0 = pd.Timestamp(f"2024-01-02 {t0_str}")
-    forecast_duration = pd.Timedelta("6h")
-    history_duration = pd.Timedelta("3h")
+    interval_start = pd.Timedelta(-3, "h")
+    interval_end = pd.Timedelta(6, "h")
     freq = pd.Timedelta("1h")
     dropout_timedelta = pd.Timedelta("-2h")
 
     t0_delayed = (t0 + dropout_timedelta).floor(NWP_FREQ)
@@ -218,8 +206,8 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
         da_nwp_like,
         t0,
         sample_period_duration=freq,
-        history_duration=history_duration,
-        forecast_duration=forecast_duration,
+        interval_start=interval_start,
+        interval_end=interval_end,
         dropout_timedeltas=[dropout_timedelta],
         dropout_frac=1,
         accum_channels=["dswrf"],
@@ -227,7 +215,7 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
     )
 
     # Check the target-times are as expected
-    expected_target_times = pd.date_range(t0 - history_duration, t0 + forecast_duration, freq=freq)
+    expected_target_times = pd.date_range(t0 + interval_start, t0 + interval_end, freq=freq)
     assert (da_slice.target_time_utc == expected_target_times).all()
 
     # Check the init-times are as expected considering the delay
@@ -254,7 +242,7 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
             init_time_utc=t0_delayed,
             channel="dswrf",
         ).diff(dim="step", label="lower")
-        .sel(step=slice(t0-t0_delayed - history_duration, t0-t0_delayed + forecast_duration))
+        .sel(step=slice(t0 - t0_delayed + interval_start, t0 - t0_delayed + interval_end))
     )
 
     # Check the values are the same
@@ -275,7 +263,7 @@ def test_select_time_slice_nwp_with_dropout_and_accum(da_nwp_like, t0_str):
             init_time_utc=t0_delayed,
             channel="t",
         )
-        .sel(step=slice(t0-t0_delayed - history_duration, t0-t0_delayed + forecast_duration))
+        .sel(step=slice(t0 - t0_delayed + interval_start, t0 - t0_delayed + interval_end))
    )
 
     # Check the values are the same
diff --git a/tests/test_data/configs/pvnet_test_config.yaml b/tests/test_data/configs/pvnet_test_config.yaml
index f6c2125..01d29c2 100644
--- a/tests/test_data/configs/pvnet_test_config.yaml
+++ b/tests/test_data/configs/pvnet_test_config.yaml
@@ -6,8 +6,8 @@ input_data:
 
   gsp:
     zarr_path: set_in_temp_file
-    history_minutes: 60
-    forecast_minutes: 120
+    interval_start_minutes: -60
+    interval_end_minutes: 120
     time_resolution_minutes: 30
     dropout_timedeltas_minutes: null
     dropout_fraction: 0
@@ -16,8 +16,8 @@ input_data:
     ukv:
       provider: ukv
      zarr_path: set_in_temp_file
-      history_minutes: 60
-      forecast_minutes: 120
+      interval_start_minutes: -60
+      interval_end_minutes: 120
       time_resolution_minutes: 60
       channels:
         - t # 2-metre temperature
@@ -29,9 +29,8 @@ input_data:
 
   satellite:
     zarr_path: set_in_temp_file
-    history_minutes: 30
-    forecast_minutes: 0
-    live_delay_minutes: 0
+    interval_start_minutes: -30
+    interval_end_minutes: 0
     time_resolution_minutes: 5
     channels:
       - IR_016
diff --git a/tests/test_data/configs/test_config.yaml b/tests/test_data/configs/test_config.yaml
index 6b3ff6d..634fe73 100644
--- a/tests/test_data/configs/test_config.yaml
+++ b/tests/test_data/configs/test_config.yaml
@@ -1,35 +1,36 @@
 general:
   description: test example configuration
   name: example
+
 input_data:
   gsp:
     zarr_path: tests/data/gsp/test.zarr
-    history_minutes: 60
-    forecast_minutes: 120
+    interval_start_minutes: -60
+    interval_end_minutes: 120
     time_resolution_minutes: 30
     dropout_timedeltas_minutes: [-30]
     dropout_fraction: 0.1
 
   nwp:
     ukv:
+      zarr_path: tests/data/nwp_data/test.zarr
+      provider: "ukv"
+      interval_start_minutes: -60
+      interval_end_minutes: 120
+      time_resolution_minutes: 60
       channels:
         - t
       image_size_pixels_height: 2
       image_size_pixels_width: 2
-      zarr_path: tests/data/nwp_data/test.zarr
-      provider: "ukv"
-      history_minutes: 60
-      forecast_minutes: 120
-      time_resolution_minutes: 60
       dropout_timedeltas_minutes: [-180]
      dropout_fraction: 1.0
+      max_staleness_minutes: null
 
   satellite:
+    zarr_path: tests/data/sat_data.zarr
+    time_resolution_minutes: 15
+    interval_start_minutes: -60
+    interval_end_minutes: 0
     channels:
       - IR_016
     image_size_pixels_height: 24
     image_size_pixels_width: 24
-    zarr_path: tests/data/sat_data.zarr
-    time_resolution_minutes: 15
-    history_minutes: 60
-    forecast_minutes: 0
-    live_delay_minutes: 0
\ No newline at end of file
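As a sanity check on the refactored schema, the updated test config should round-trip through the loader along these lines (a sketch; the asserts mirror the YAML above):

```python
from ocf_data_sampler.config import load_yaml_configuration

config = load_yaml_configuration("tests/test_data/configs/test_config.yaml")

assert config.input_data.gsp.interval_start_minutes == -60
assert config.input_data.nwp["ukv"].interval_end_minutes == 120
assert config.input_data.satellite.interval_end_minutes == 0
```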