Skip to content

Commit

Permalink
add config model and corresponding tests (#45)
Browse files Browse the repository at this point in the history
* add config model and corresponding tests

* strip out extras from config model and corresponding tests and test data

* rm extra test config .yaml

* fix incomplete refactoring of test_config

* move pvnet_test_config.yaml to test_data/configs

* add config functions to configs.__init__ and refactor

* remove config readme

* update test_config.yaml

* clean imports and type hints

* adding and amending config tests + small validator fix

* docstrings and added test in config tests

* docstrings and fix dropout test

* added dropout consistency test + docstrings

* update tests

* rm duplicate test from dropout.py

* copy pyproject.toml from main

* assertion in dropout.py made consistent with config test

---------
  • Loading branch information
AUdaltsova authored Oct 4, 2024
1 parent d0d8290 commit a6b0831
Show file tree
Hide file tree
Showing 11 changed files with 520 additions and 10 deletions.
5 changes: 5 additions & 0 deletions ocf_data_sampler/config/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Configuration model"""

from ocf_data_sampler.config.model import Configuration
from ocf_data_sampler.config.save import save_yaml_configuration
from ocf_data_sampler.config.load import load_yaml_configuration
33 changes: 33 additions & 0 deletions ocf_data_sampler/config/load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Loading configuration functions.
Example:
from ocf_data_sampler.config import load_yaml_configuration
configuration = load_yaml_configuration(filename)
"""

import fsspec
from pathy import Pathy
from pyaml_env import parse_config

from ocf_data_sampler.config import Configuration


def load_yaml_configuration(filename: str | Pathy) -> Configuration:
    """Load a yaml file which has a configuration in it.

    Args:
        filename: the file name that you want to load. Will load from local, AWS, or GCP
            depending on the protocol suffix (e.g. 's3://bucket/config.yaml').

    Returns:
        A validated `Configuration` pydantic object.
    """
    # fsspec selects the filesystem from the protocol prefix; parse_config
    # additionally lets the yaml reference environment variables.
    with fsspec.open(filename, mode="r") as stream:
        config_dict = parse_config(data=stream)
    # Validate the raw dictionary into the pydantic model.
    return Configuration(**config_dict)
249 changes: 249 additions & 0 deletions ocf_data_sampler/config/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
"""Configuration model for the dataset.
All paths must include the protocol prefix. For local files,
it's sufficient to just start with a '/'. For aws, start with 's3://',
for gcp start with 'gs://'.
Example:
from ocf_data_sampler.config import Configuration
config = Configuration(**config_dict)
"""

import logging
from typing import Dict, List, Optional

from ocf_datapipes.utils.consts import NWP_PROVIDERS
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    RootModel,
    ValidationInfo,
    field_validator,
    model_validator,
)
from typing_extensions import Self

logger = logging.getLogger(__name__)

providers = ["pvoutput.org", "solar_sheffield_passiv"]


class Base(BaseModel):
    """Pydantic Base model where no extras can be added.

    Forbidding extras means a misspelled key in a config file raises a
    validation error instead of being silently ignored.
    """

    # pydantic v2 configuration style; the nested `class Config` form used
    # previously is deprecated in v2 (which this file targets, given its use
    # of field_validator/model_validator).
    model_config = ConfigDict(extra="forbid")  # forbid use of extra kwargs


class General(Base):
    """General pydantic model"""

    # Short identifier for this configuration.
    name: str = Field(default="example", description="The name of this configuration file.")
    # Free-text summary of what the configuration is for.
    description: str = Field(
        default="example configuration",
        description="Description of this configuration file",
    )


class DataSourceMixin(Base):
    """Mixin class, to add forecast and history minutes"""

    # Length of the forecast window, in minutes; must be non-negative.
    forecast_minutes: int = Field(
        ...,
        ge=0,
        description="how many minutes to forecast in the future. ",
    )
    # Length of the history window, in minutes; must be non-negative.
    history_minutes: int = Field(
        ...,
        ge=0,
        description="how many historic minutes to use. ",
    )


# noinspection PyMethodParameters
class DropoutMixin(Base):
    """Mixin class, to add dropout minutes"""

    dropout_timedeltas_minutes: Optional[List[int]] = Field(
        default=None,
        description="List of possible minutes before t0 where data availability may start. Must be "
        "negative or zero.",
    )

    dropout_fraction: float = Field(0, description="Chance of dropout being applied to each sample")

    @field_validator("dropout_timedeltas_minutes")
    def dropout_timedeltas_minutes_negative(cls, v: Optional[List[int]]) -> Optional[List[int]]:
        """Validate 'dropout_timedeltas_minutes'.

        Raises ValueError rather than using assert, which is stripped
        under `python -O` and so gives no protection there.
        """
        if v is not None:
            for m in v:
                if m > 0:
                    raise ValueError("Dropout timedeltas must be negative")
        return v

    @field_validator("dropout_fraction")
    def dropout_fraction_valid(cls, v: float) -> float:
        """Validate 'dropout_fraction' is a probability in [0, 1]."""
        if not 0 <= v <= 1:
            raise ValueError("Dropout fraction must be between 0 and 1")
        return v

    @model_validator(mode="after")
    def dropout_instructions_consistent(self) -> Self:
        """Check the fraction and the timedeltas are either both set or both unset."""
        if self.dropout_fraction == 0:
            if self.dropout_timedeltas_minutes is not None:
                raise ValueError("To use dropout timedeltas dropout fraction should be > 0")
        else:
            if self.dropout_timedeltas_minutes is None:
                raise ValueError("Dropout fraction > 0 requires a list of dropout timedeltas")
        return self


# noinspection PyMethodParameters
class TimeResolutionMixin(Base):
    """Time resolution mix in"""

    # Temporal resolution of the data in minutes (e.g. 5 for five-minutely
    # data); forecast/history validators elsewhere require window lengths to
    # be divisible by this.
    time_resolution_minutes: int = Field(
        ...,
        description="The temporal resolution of the data in minutes",
    )


class Satellite(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
    """Satellite configuration model"""

    # Todo: remove 'satellite' from names
    satellite_zarr_path: str | tuple[str] | list[str] = Field(
        ...,
        description="The path or list of paths which hold the satellite zarr",
    )
    # Channel names to load from the satellite zarr.
    satellite_channels: list[str] = Field(
        ..., description="the satellite channels that are used"
    )
    # Height of the spatial region of interest, in pixels.
    satellite_image_size_pixels_height: int = Field(
        ...,
        description="The number of pixels of the height of the region of interest"
        " for non-HRV satellite channels.",
    )

    # Width of the spatial region of interest, in pixels.
    satellite_image_size_pixels_width: int = Field(
        ...,
        description="The number of pixels of the width of the region "
        "of interest for non-HRV satellite channels.",
    )

    # Expected latency of the satellite feed in minutes; presumably used to
    # shift the availability window at sample time — confirm against sampler.
    live_delay_minutes: int = Field(
        ..., description="The expected delay in minutes of the satellite data"
    )


# noinspection PyMethodParameters
class NWP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
    """NWP configuration model"""

    nwp_zarr_path: str | tuple[str] | list[str] = Field(
        ...,
        description="The path which holds the NWP zarr",
    )
    nwp_channels: list[str] = Field(
        ..., description="the channels used in the nwp data"
    )
    nwp_accum_channels: list[str] = Field([], description="the nwp channels which need to be diffed")
    nwp_image_size_pixels_height: int = Field(..., description="The size of NWP spacial crop in pixels")
    nwp_image_size_pixels_width: int = Field(..., description="The size of NWP spacial crop in pixels")

    nwp_provider: str = Field(..., description="The provider of the NWP data")

    max_staleness_minutes: Optional[int] = Field(
        None,
        description="Sets a limit on how stale an NWP init time is allowed to be whilst still being"
        " used to construct an example. If set to None, then the max staleness is set according to"
        " the maximum forecast horizon of the NWP and the requested forecast length.",
    )

    @field_validator("nwp_provider")
    def validate_nwp_provider(cls, v: str) -> str:
        """Validate 'nwp_provider' against the known NWP_PROVIDERS list."""
        if v.lower() not in NWP_PROVIDERS:
            message = f"NWP provider {v} is not in {NWP_PROVIDERS}"
            logger.warning(message)
            # ValueError rather than bare Exception: pydantic reports it as a
            # normal validation error, and callers catching Exception still work.
            raise ValueError(message)
        return v

    # Todo: put into time mixin when moving intervals there
    @field_validator("forecast_minutes")
    def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
        """Check the forecast length is a whole number of time-resolution steps."""
        if v % info.data["time_resolution_minutes"] != 0:
            message = "Forecast duration must be divisible by time resolution"
            logger.error(message)
            raise ValueError(message)
        return v

    @field_validator("history_minutes")
    def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
        """Check the history length is a whole number of time-resolution steps."""
        if v % info.data["time_resolution_minutes"] != 0:
            message = "History duration must be divisible by time resolution"
            logger.error(message)
            raise ValueError(message)
        return v


class MultiNWP(RootModel):
    """Configuration for multiple NWPs, keyed by source name"""

    root: Dict[str, NWP]

    def __getattr__(self, item):
        """Attribute-style access to a named NWP config (config.nwp.ukv)."""
        try:
            return self.root[item]
        except KeyError:
            # __getattr__ must raise AttributeError, not KeyError, otherwise
            # hasattr(), getattr(obj, name, default), copy and pickle all
            # misbehave on missing names.
            raise AttributeError(item) from None

    def __getitem__(self, item):
        """Dictionary-style access to a named NWP config."""
        return self.root[item]

    def __len__(self):
        """Number of configured NWP sources."""
        return len(self.root)

    def __iter__(self):
        """Iterate over NWP source names, like a dict."""
        return iter(self.root)

    def keys(self):
        """Returns dictionary-like keys"""
        return self.root.keys()

    def items(self):
        """Returns dictionary-like items"""
        return self.root.items()


# noinspection PyMethodParameters
class GSP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
    """GSP configuration model"""

    gsp_zarr_path: str = Field(..., description="The path which holds the GSP zarr")

    # Todo: these validators duplicate the ones on NWP — fold into the time
    # mixin when moving intervals there
    @field_validator("forecast_minutes")
    def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
        """Check the forecast length is a whole number of time-resolution steps."""
        if v % info.data["time_resolution_minutes"] != 0:
            message = "Forecast duration must be divisible by time resolution"
            logger.error(message)
            # ValueError rather than bare Exception so pydantic reports a
            # normal validation error; callers catching Exception still work.
            raise ValueError(message)
        return v

    @field_validator("history_minutes")
    def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
        """Check the history length is a whole number of time-resolution steps."""
        if v % info.data["time_resolution_minutes"] != 0:
            message = "History duration must be divisible by time resolution"
            logger.error(message)
            raise ValueError(message)
        return v


# noinspection PyPep8Naming
class InputData(Base):
    """
    Input data model.

    Each data source is optional; a source left as None is simply not used.
    """

    # Satellite imagery settings (None = no satellite input).
    satellite: Optional[Satellite] = None
    # One or more NWP sources keyed by name (None = no NWP input).
    nwp: Optional[MultiNWP] = None
    # GSP settings (None = no GSP input).
    gsp: Optional[GSP] = None


class Configuration(Base):
    """Configuration model for the dataset"""

    # General metadata; defaults to the example name/description.
    general: General = General()
    # Per-source input settings; all sources default to unset.
    input_data: InputData = InputData()
36 changes: 36 additions & 0 deletions ocf_data_sampler/config/save.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Save functions for the configuration model.
Example:
from ocf_data_sampler.config import save_yaml_configuration
configuration = save_yaml_configuration(config, filename)
"""

import json

import fsspec
import yaml
from pathy import Pathy

from ocf_data_sampler.config import Configuration


def save_yaml_configuration(
    configuration: Configuration, filename: str | Pathy | None
):
    """
    Save a yaml file which has the configuration in it.

    Args:
        configuration: the configuration pydantic model to save.
        filename: where to save it. Will save to GCP, AWS, or local, depending
            on the protocol suffix of the path. If None, saves to
            `configuration.output_data.filepath / configuration.yaml`.
            NOTE(review): the annotation previously omitted None even though
            the body handles it; also `Configuration` does not appear to
            declare an `output_data` field, so the None branch may raise
            AttributeError — verify before relying on it.
    """
    # Round-trip through JSON so the object only contains json-able values
    # (plain dicts/lists/strings) before handing it to yaml.
    d = json.loads(configuration.model_dump_json())
    if filename is None:
        filename = Pathy(configuration.output_data.filepath) / "configuration.yaml"

    # save to a yaml file; fsspec picks the filesystem from the protocol prefix
    with fsspec.open(filename, "w") as yaml_file:
        yaml.safe_dump(d, yaml_file, default_flow_style=False)
4 changes: 2 additions & 2 deletions ocf_data_sampler/select/dropout.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def draw_dropout_time(
if dropout_timedeltas is not None:
assert len(dropout_timedeltas) >= 1, "Must include list of relative dropout timedeltas"
assert all(
[t < pd.Timedelta("0min") for t in dropout_timedeltas]
[t <= pd.Timedelta("0min") for t in dropout_timedeltas]
), "dropout timedeltas must be negative"
assert 0 <= dropout_frac <= 1

Expand All @@ -35,4 +35,4 @@ def apply_dropout_time(
return ds
else:
# This replaces the times after the dropout with NaNs
return ds.where(ds.time_utc <= dropout_time)
return ds.where(ds.time_utc <= dropout_time)
3 changes: 1 addition & 2 deletions ocf_data_sampler/torch_datasets/pvnet_uk_regional.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
)


from ocf_datapipes.config.model import Configuration
from ocf_datapipes.config.load import load_yaml_configuration
from ocf_data_sampler.config import Configuration, load_yaml_configuration
from ocf_datapipes.batch import BatchKey, NumpyBatch

from ocf_datapipes.utils.location import Location
Expand Down
Loading

0 comments on commit a6b0831

Please sign in to comment.