diff --git a/ocf_data_sampler/config/__init__.py b/ocf_data_sampler/config/__init__.py new file mode 100644 index 0000000..ccfb3fe --- /dev/null +++ b/ocf_data_sampler/config/__init__.py @@ -0,0 +1,5 @@ +"""Configuration model""" + +from ocf_data_sampler.config.model import Configuration +from ocf_data_sampler.config.save import save_yaml_configuration +from ocf_data_sampler.config.load import load_yaml_configuration \ No newline at end of file diff --git a/ocf_data_sampler/config/load.py b/ocf_data_sampler/config/load.py new file mode 100644 index 0000000..e630809 --- /dev/null +++ b/ocf_data_sampler/config/load.py @@ -0,0 +1,33 @@ +"""Loading configuration functions. + +Example: + + from ocf_data_sampler.config import load_yaml_configuration + configuration = load_yaml_configuration(filename) +""" + +import fsspec +from pathy import Pathy +from pyaml_env import parse_config + +from ocf_data_sampler.config import Configuration + + +def load_yaml_configuration(filename: str | Pathy) -> Configuration: + """ + Load a yaml file which has a configuration in it + + Args: + filename: the file name that you want to load. Will load from local, AWS, or GCP + depending on the protocol suffix (e.g. 's3://bucket/config.yaml'). + + Returns:pydantic class + + """ + # load the file to a dictionary + with fsspec.open(filename, mode="r") as stream: + configuration = parse_config(data=stream) + # this means we can load ENVs in the yaml file + # turn into pydantic class + configuration = Configuration(**configuration) + return configuration diff --git a/ocf_data_sampler/config/model.py b/ocf_data_sampler/config/model.py new file mode 100644 index 0000000..34ca90c --- /dev/null +++ b/ocf_data_sampler/config/model.py @@ -0,0 +1,249 @@ +"""Configuration model for the dataset. + +All paths must include the protocol prefix. For local files, +it's sufficient to just start with a '/'. For aws, start with 's3://', +for gcp start with 'gs://'. + +Example: + + from ocf_data_sampler.config import Configuration + config = Configuration(**config_dict) +""" + +import logging +from typing import Dict, List, Optional +from typing_extensions import Self + +from pydantic import BaseModel, Field, RootModel, field_validator, ValidationInfo, model_validator +from ocf_datapipes.utils.consts import NWP_PROVIDERS + +logger = logging.getLogger(__name__) + +providers = ["pvoutput.org", "solar_sheffield_passiv"] + + +class Base(BaseModel): + """Pydantic Base model where no extras can be added""" + + class Config: + """config class""" + + extra = "forbid" # forbid use of extra kwargs + + +class General(Base): + """General pydantic model""" + + name: str = Field("example", description="The name of this configuration file.") + description: str = Field( + "example configuration", description="Description of this configuration file" + ) + + +class DataSourceMixin(Base): + """Mixin class, to add forecast and history minutes""" + + forecast_minutes: int = Field( + ..., + ge=0, + description="how many minutes to forecast in the future. ", + ) + history_minutes: int = Field( + ..., + ge=0, + description="how many historic minutes to use. ", + ) + + +# noinspection PyMethodParameters +class DropoutMixin(Base): + """Mixin class, to add dropout minutes""" + + dropout_timedeltas_minutes: Optional[List[int]] = Field( + default=None, + description="List of possible minutes before t0 where data availability may start. Must be " + "negative or zero.", + ) + + dropout_fraction: float = Field(0, description="Chance of dropout being applied to each sample") + + @field_validator("dropout_timedeltas_minutes") + def dropout_timedeltas_minutes_negative(cls, v: List[int]) -> List[int]: + """Validate 'dropout_timedeltas_minutes'""" + if v is not None: + for m in v: + assert m <= 0, "Dropout timedeltas must be negative" + return v + + @field_validator("dropout_fraction") + def dropout_fraction_valid(cls, v: float) -> float: + """Validate 'dropout_fraction'""" + assert 0 <= v <= 1, "Dropout fraction must be between 0 and 1" + return v + + @model_validator(mode="after") + def dropout_instructions_consistent(self) -> Self: + if self.dropout_fraction == 0: + if self.dropout_timedeltas_minutes is not None: + raise ValueError("To use dropout timedeltas dropout fraction should be > 0") + else: + if self.dropout_timedeltas_minutes is None: + raise ValueError("To dropout fraction > 0 requires a list of dropout timedeltas") + return self + + +# noinspection PyMethodParameters +class TimeResolutionMixin(Base): + """Time resolution mix in""" + + time_resolution_minutes: int = Field( + ..., + description="The temporal resolution of the data in minutes", + ) + + +class Satellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin): + """Satellite configuration model""" + + # Todo: remove 'satellite' from names + satellite_zarr_path: str | tuple[str] | list[str] = Field( + ..., + description="The path or list of paths which hold the satellite zarr", + ) + satellite_channels: list[str] = Field( + ..., description="the satellite channels that are used" + ) + satellite_image_size_pixels_height: int = Field( + ..., + description="The number of pixels of the height of the region of interest" + " for non-HRV satellite channels.", + ) + + satellite_image_size_pixels_width: int = Field( + ..., + description="The number of pixels of the width of the region " + "of interest for non-HRV satellite channels.", + ) + + live_delay_minutes: int = Field( + ..., description="The expected delay in minutes of the satellite data" + ) + + +# noinspection PyMethodParameters +class NWP(DataSourceMixin, TimeResolutionMixin, DropoutMixin): + """NWP configuration model""" + + nwp_zarr_path: str | tuple[str] | list[str] = Field( + ..., + description="The path which holds the NWP zarr", + ) + nwp_channels: list[str] = Field( + ..., description="the channels used in the nwp data" + ) + nwp_accum_channels: list[str] = Field([], description="the nwp channels which need to be diffed") + nwp_image_size_pixels_height: int = Field(..., description="The size of NWP spacial crop in pixels") + nwp_image_size_pixels_width: int = Field(..., description="The size of NWP spacial crop in pixels") + + nwp_provider: str = Field(..., description="The provider of the NWP data") + + max_staleness_minutes: Optional[int] = Field( + None, + description="Sets a limit on how stale an NWP init time is allowed to be whilst still being" + " used to construct an example. If set to None, then the max staleness is set according to" + " the maximum forecast horizon of the NWP and the requested forecast length.", + ) + + + @field_validator("nwp_provider") + def validate_nwp_provider(cls, v: str) -> str: + """Validate 'nwp_provider'""" + if v.lower() not in NWP_PROVIDERS: + message = f"NWP provider {v} is not in {NWP_PROVIDERS}" + logger.warning(message) + raise Exception(message) + return v + + # Todo: put into time mixin when moving intervals there + @field_validator("forecast_minutes") + def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int: + if v % info.data["time_resolution_minutes"] != 0: + message = "Forecast duration must be divisible by time resolution" + logger.error(message) + raise Exception(message) + return v + + @field_validator("history_minutes") + def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int: + if v % info.data["time_resolution_minutes"] != 0: + message = "History duration must be divisible by time resolution" + logger.error(message) + raise Exception(message) + return v + + +class MultiNWP(RootModel): + """Configuration for multiple NWPs""" + + root: Dict[str, NWP] + + def __getattr__(self, item): + return self.root[item] + + def __getitem__(self, item): + return self.root[item] + + def __len__(self): + return len(self.root) + + def __iter__(self): + return iter(self.root) + + def keys(self): + """Returns dictionary-like keys""" + return self.root.keys() + + def items(self): + """Returns dictionary-like items""" + return self.root.items() + + +# noinspection PyMethodParameters +class GSP(DataSourceMixin, TimeResolutionMixin, DropoutMixin): + """GSP configuration model""" + + gsp_zarr_path: str = Field(..., description="The path which holds the GSP zarr") + + @field_validator("forecast_minutes") + def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int: + if v % info.data["time_resolution_minutes"] != 0: + message = "Forecast duration must be divisible by time resolution" + logger.error(message) + raise Exception(message) + return v + + @field_validator("history_minutes") + def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int: + if v % info.data["time_resolution_minutes"] != 0: + message = "History duration must be divisible by time resolution" + logger.error(message) + raise Exception(message) + return v + + +# noinspection PyPep8Naming +class InputData(Base): + """ + Input data model. + """ + + satellite: Optional[Satellite] = None + nwp: Optional[MultiNWP] = None + gsp: Optional[GSP] = None + + +class Configuration(Base): + """Configuration model for the dataset""" + + general: General = General() + input_data: InputData = InputData() \ No newline at end of file diff --git a/ocf_data_sampler/config/save.py b/ocf_data_sampler/config/save.py new file mode 100644 index 0000000..fe8faea --- /dev/null +++ b/ocf_data_sampler/config/save.py @@ -0,0 +1,36 @@ +"""Save functions for the configuration model. + +Example: + + from ocf_data_sampler.config import save_yaml_configuration + configuration = save_yaml_configuration(config, filename) +""" + +import json + +import fsspec +import yaml +from pathy import Pathy + +from ocf_data_sampler.config import Configuration + + +def save_yaml_configuration( + configuration: Configuration, filename: str | Pathy +): + """ + Save a local yaml file which has the configuration in it. + + If `filename` is None then saves to configuration.output_data.filepath / configuration.yaml. + + Will save to GCP, AWS, or local, depending on the protocol suffix of filepath. + """ + # make a dictionary from the configuration, + # Note that we make the object json'able first, so that it can be saved to a yaml file + d = json.loads(configuration.model_dump_json()) + if filename is None: + filename = Pathy(configuration.output_data.filepath) / "configuration.yaml" + + # save to a yaml file + with fsspec.open(filename, "w") as yaml_file: + yaml.safe_dump(d, yaml_file, default_flow_style=False) diff --git a/ocf_data_sampler/select/dropout.py b/ocf_data_sampler/select/dropout.py index 5e45fee..2405546 100644 --- a/ocf_data_sampler/select/dropout.py +++ b/ocf_data_sampler/select/dropout.py @@ -12,7 +12,7 @@ def draw_dropout_time( if dropout_timedeltas is not None: assert len(dropout_timedeltas) >= 1, "Must include list of relative dropout timedeltas" assert all( - [t < pd.Timedelta("0min") for t in dropout_timedeltas] + [t <= pd.Timedelta("0min") for t in dropout_timedeltas] ), "dropout timedeltas must be negative" assert 0 <= dropout_frac <= 1 @@ -35,4 +35,4 @@ def apply_dropout_time( return ds else: # This replaces the times after the dropout with NaNs - return ds.where(ds.time_utc <= dropout_time) \ No newline at end of file + return ds.where(ds.time_utc <= dropout_time) diff --git a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py index b4020fe..ecd952a 100644 --- a/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py +++ b/ocf_data_sampler/torch_datasets/pvnet_uk_regional.py @@ -27,8 +27,7 @@ ) -from ocf_datapipes.config.model import Configuration -from ocf_datapipes.config.load import load_yaml_configuration +from ocf_data_sampler.config import Configuration, load_yaml_configuration from ocf_datapipes.batch import BatchKey, NumpyBatch from ocf_datapipes.utils.location import Location diff --git a/tests/config/test_config.py b/tests/config/test_config.py new file mode 100644 index 0000000..df5d418 --- /dev/null +++ b/tests/config/test_config.py @@ -0,0 +1,152 @@ +import tempfile + +import pytest +from pydantic import ValidationError + +from ocf_data_sampler.config import ( + load_yaml_configuration, + Configuration, + save_yaml_configuration +) + + +def test_default(): + """Test default pydantic class""" + + _ = Configuration() + + +def test_yaml_load_test_config(test_config_filename): + """ + Test that yaml loading works for 'test_config.yaml' + and fails for an empty .yaml file + """ + + # check we get an error if loading a file with no config + with tempfile.NamedTemporaryFile(suffix=".yaml") as fp: + filename = fp.name + + # check that temp file can't be loaded + with pytest.raises(TypeError): + _ = load_yaml_configuration(filename) + + # test can load test_config.yaml + config = load_yaml_configuration(test_config_filename) + + assert isinstance(config, Configuration) + + +def test_yaml_save(test_config_filename): + """ + Check configuration can be saved to a .yaml file + """ + + test_config = load_yaml_configuration(test_config_filename) + + with tempfile.NamedTemporaryFile(suffix=".yaml") as fp: + filename = fp.name + + # save default config to file + save_yaml_configuration(test_config, filename) + + # check the file can be loaded back + tmp_config = load_yaml_configuration(filename) + + # check loaded configuration is the same as the one passed to save + assert test_config == tmp_config + + +def test_extra_field(): + """ + Check an extra parameters in config causes error + """ + + configuration = Configuration() + configuration_dict = configuration.model_dump() + configuration_dict["extra_field"] = "extra_value" + with pytest.raises(ValidationError, match="Extra inputs are not permitted"): + _ = Configuration(**configuration_dict) + + +def test_incorrect_forecast_minutes(test_config_filename): + """ + Check a forecast length not divisible by time resolution causes error + """ + + configuration = load_yaml_configuration(test_config_filename) + + configuration.input_data.nwp['ukv'].forecast_minutes = 1111 + with pytest.raises(Exception, match="duration must be divisible by time resolution"): + _ = Configuration(**configuration.model_dump()) + + +def test_incorrect_history_minutes(test_config_filename): + """ + Check a history length not divisible by time resolution causes error + """ + + configuration = load_yaml_configuration(test_config_filename) + + configuration.input_data.nwp['ukv'].history_minutes = 1111 + with pytest.raises(Exception, match="duration must be divisible by time resolution"): + _ = Configuration(**configuration.model_dump()) + + +def test_incorrect_nwp_provider(test_config_filename): + """ + Check an unexpected nwp provider causes error + """ + + configuration = load_yaml_configuration(test_config_filename) + + configuration.input_data.nwp['ukv'].nwp_provider = "unexpected_provider" + with pytest.raises(Exception, match="NWP provider"): + _ = Configuration(**configuration.model_dump()) + +def test_incorrect_dropout(test_config_filename): + """ + Check a dropout timedelta over 0 causes error and 0 doesn't + """ + + configuration = load_yaml_configuration(test_config_filename) + + # check a positive number is not allowed + configuration.input_data.nwp['ukv'].dropout_timedeltas_minutes = [120] + with pytest.raises(Exception, match="Dropout timedeltas must be negative"): + _ = Configuration(**configuration.model_dump()) + + # check 0 is allowed + configuration.input_data.nwp['ukv'].dropout_timedeltas_minutes = [0] + _ = Configuration(**configuration.model_dump()) + +def test_incorrect_dropout_fraction(test_config_filename): + """ + Check dropout fraction outside of range causes error + """ + + configuration = load_yaml_configuration(test_config_filename) + + configuration.input_data.nwp['ukv'].dropout_fraction= 1.1 + with pytest.raises(Exception, match="Dropout fraction must be between 0 and 1"): + _ = Configuration(**configuration.model_dump()) + + configuration.input_data.nwp['ukv'].dropout_fraction= -0.1 + with pytest.raises(Exception, match="Dropout fraction must be between 0 and 1"): + _ = Configuration(**configuration.model_dump()) + + +def test_inconsistent_dropout_use(test_config_filename): + """ + Check dropout fraction outside of range causes error + """ + + configuration = load_yaml_configuration(test_config_filename) + configuration.input_data.satellite.dropout_fraction= 1.0 + configuration.input_data.satellite.dropout_timedeltas_minutes = None + + with pytest.raises(ValueError, match="To dropout fraction > 0 requires a list of dropout timedeltas"): + _ = Configuration(**configuration.model_dump()) + configuration.input_data.satellite.dropout_fraction= 0.0 + configuration.input_data.satellite.dropout_timedeltas_minutes = [-120, -60] + with pytest.raises(ValueError, match="To use dropout timedeltas dropout fraction should be > 0"): + _ = Configuration(**configuration.model_dump()) \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 81a2072..b92b801 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,11 +6,16 @@ import xarray as xr import tempfile +_top_test_directory = os.path.dirname(os.path.realpath(__file__)) + +@pytest.fixture() +def test_config_filename(): + return f"{_top_test_directory}/test_data/configs/test_config.yaml" @pytest.fixture(scope="session") def config_filename(): - return f"{os.path.dirname(os.path.abspath(__file__))}/test_data/pvnet_test_config.yaml" + return f"{os.path.dirname(os.path.abspath(__file__))}/test_data/configs/pvnet_test_config.yaml" @pytest.fixture(scope="session") diff --git a/tests/test_data/pvnet_test_config.yaml b/tests/test_data/configs/pvnet_test_config.yaml similarity index 94% rename from tests/test_data/pvnet_test_config.yaml rename to tests/test_data/configs/pvnet_test_config.yaml index e32d6f6..1625ace 100644 --- a/tests/test_data/pvnet_test_config.yaml +++ b/tests/test_data/configs/pvnet_test_config.yaml @@ -3,8 +3,6 @@ general: name: pvnet_test input_data: - default_history_minutes: 60 - default_forecast_minutes: 120 gsp: gsp_zarr_path: set_in_temp_file diff --git a/tests/test_data/configs/test_config.yaml b/tests/test_data/configs/test_config.yaml new file mode 100644 index 0000000..024f2a3 --- /dev/null +++ b/tests/test_data/configs/test_config.yaml @@ -0,0 +1,35 @@ +general: + description: test example configuration + name: example +input_data: + gsp: + gsp_zarr_path: tests/data/gsp/test.zarr + history_minutes: 60 + forecast_minutes: 120 + time_resolution_minutes: 30 + dropout_timedeltas_minutes: [-30] + dropout_fraction: 0.1 + nwp: + ukv: + nwp_channels: + - t + nwp_image_size_pixels_height: 2 + nwp_image_size_pixels_width: 2 + nwp_zarr_path: tests/data/nwp_data/test.zarr + nwp_provider: "ukv" + history_minutes: 60 + forecast_minutes: 120 + time_resolution_minutes: 60 + dropout_timedeltas_minutes: [-180] + dropout_fraction: 1.0 + + satellite: + satellite_channels: + - IR_016 + satellite_image_size_pixels_height: 24 + satellite_image_size_pixels_width: 24 + satellite_zarr_path: tests/data/sat_data.zarr + time_resolution_minutes: 15 + history_minutes: 60 + forecast_minutes: 0 + live_delay_minutes: 0 \ No newline at end of file diff --git a/tests/torch_datasets/test_pvnet_uk_regional.py b/tests/torch_datasets/test_pvnet_uk_regional.py index 1e4ac88..490c7b5 100644 --- a/tests/torch_datasets/test_pvnet_uk_regional.py +++ b/tests/torch_datasets/test_pvnet_uk_regional.py @@ -2,8 +2,7 @@ import tempfile from ocf_data_sampler.torch_datasets.pvnet_uk_regional import PVNetUKRegionalDataset -from ocf_datapipes.config.load import load_yaml_configuration -from ocf_datapipes.config.save import save_yaml_configuration +from ocf_data_sampler.config import load_yaml_configuration, save_yaml_configuration from ocf_datapipes.batch import BatchKey, NWPBatchKey @@ -56,7 +55,6 @@ def test_pvnet(pvnet_config_filename): assert sample[BatchKey.gsp_solar_azimuth].shape == (7,) assert sample[BatchKey.gsp_solar_elevation].shape == (7,) - def test_pvnet_no_gsp(pvnet_config_filename): # load config