diff --git a/Parser_PoC/parse_preview.py b/Parser_PoC/parse_preview.py index 7f2683f..65f3e89 100755 --- a/Parser_PoC/parse_preview.py +++ b/Parser_PoC/parse_preview.py @@ -29,7 +29,7 @@ def __init__(self, name, run_spec): self.run_spec = run_spec self.input = [] self.output = [] - self.depends = [] + self.wait_on = [] class WcData(): @@ -279,12 +279,12 @@ def _add_edges_from_cycle(self, cycle): if not out_node.is_concrete(): cluster.append(out_node) # add dependencies - for dep_spec in task_graph_spec.get('depends', []): + for dep_spec in task_graph_spec.get('wait_on', []): if (dep_node := self.get_task(dep_spec)): if isinstance(dep_node, list): - task_node.depends.extend(dep_node) + task_node.wait_on.extend(dep_node) else: - task_node.depends.append(dep_node) + task_node.wait_on.append(dep_node) self.add_edge(dep_node, task_node) # Add clsuter d1 = self.cycling_date diff --git a/Parser_PoC/test_config.yaml b/Parser_PoC/test_config.yaml index be27813..93cb5ec 100644 --- a/Parser_PoC/test_config.yaml +++ b/Parser_PoC/test_config.yaml @@ -23,7 +23,7 @@ scheduling: - ERA5 output: - icon input - depends: + wait_on: - ICON: lag: '-P4M' # lag: '-P6M' diff --git a/src/sirocco/core.py b/src/sirocco/core.py index 9d73f28..fd73bef 100644 --- a/src/sirocco/core.py +++ b/src/sirocco/core.py @@ -6,8 +6,8 @@ from sirocco.parsing._yaml_data_models import ( ConfigCycleTask, - ConfigCycleTaskDepend, ConfigCycleTaskInput, + ConfigCycleTaskWaitOn, ConfigTask, ConfigWorkflow, load_workflow_config, @@ -19,13 +19,13 @@ from sirocco.parsing._yaml_data_models import ConfigCycle, DataBaseModel - type ConfigCycleSpec = ConfigCycleTaskDepend | ConfigCycleTaskInput + type ConfigCycleSpec = ConfigCycleTaskWaitOn | ConfigCycleTaskInput logging.basicConfig() logger = logging.getLogger(__name__) -TimeSeriesObject = TypeVar("TimeSeriesObject") +StoreObject = TypeVar("StoreObject") class BaseNode: @@ -69,7 +69,7 @@ def from_config( inputs: list[Data] = [] for input_spec in task_ref.inputs: inputs.extend(data for data in workflow.data.iter_from_cycle_spec(input_spec, date) if data is not None) - outputs: list[Data] = [workflow.data[output_spec.name, date] for output_spec in task_ref.outputs] + outputs: list[Data] = [workflow.data[output_spec.name, {"date": date}] for output_spec in task_ref.outputs] new = cls( date=date, @@ -80,7 +80,7 @@ def from_config( ) # this works because dataclass has generated this init for us # Store for actual linking in link_wait_on_tasks() once all tasks are created - new._wait_on_specs = task_ref.depends # noqa: SLF001 we don't have access to self in a dataclass + new._wait_on_specs = task_ref.wait_on # noqa: SLF001 we don't have access to self in a dataclass # and setting an underscored attribute from # the class itself raises SLF001 @@ -126,121 +126,161 @@ class Cycle(BaseNode): date: datetime | None = None -class TimeSeries(Generic[TimeSeriesObject]): - """Dictionnary of objects accessed by date, checking start and end dates""" - - def __init__(self) -> None: - self.start_date: datetime | None = None - self.end_date: datetime | None = None - self._dict: dict[str:TimeSeriesObject] = {} - - def __setitem__(self, date: datetime, data: TimeSeriesObject) -> None: - if date in self._dict: - msg = f"date {date} already used, cannot set twice" +class ParamSeries(Generic[StoreObject]): + """Dictionnary of objects accessed by arbitrary parameters""" + + def __init__(self, name: str) -> None: + self._name = name + self._dims: set | None = None + self._axes: dict | None = None + self._dict: dict[tuple:StoreObject] | None = None + + def __setitem__(self, parameters: dict, value: StoreObject) -> None: + # First access: set axes and initialize dictionnary + param_keys = set(parameters.keys()) + if self._dims is None: + self._dims = param_keys + self._axes = {k: set() for k in param_keys} + self._dict = {} + # check dimensions + elif self._dims != param_keys: + msg = ( + f"ParamSeries {self._name}: parameter keys {param_keys} don't match ParamSeries dimensions {self._dims}" + ) raise KeyError(msg) - self._dict[date] = data - if self.start_date is None: - self.start_date = date - self.end_date = date - elif date < self.start_date: - self.start_date = date - elif date > self.end_date: - self.end_date = date - - def __getitem__(self, date: datetime) -> TimeSeriesObject: - if self.start_date is None: - msg = "TimeSeries still empty, cannot access by date" - raise ValueError(msg) - if date < self.start_date or date > self.end_date: - item = next(iter(self._dict.values())) + # Build internal key + # use the order of self._dims instead of param_keys to ensure reproducibility + key = tuple(parameters[dim] for dim in self._dims) + # Check if slot already taken + if key in self._dict: + msg = f"ParamSeries {self._name}: key {key} already used, cannot set item twice" + raise KeyError(msg) + # Store new axes values + for dim in self._dims: + self._axes[dim].add(parameters[dim]) + # Set item + self._dict[key] = value + + def __getitem__(self, parameters: dict) -> StoreObject: + if self._dims != (param_keys := set(parameters.keys())): msg = ( - f"date {date} for item '{item.name}' is out of bounds [{self.start_date} - {self.end_date}], ignoring." + f"ParamSeries {self._name}: parameter keys {param_keys} don't match ParamSeries dimensions {self._dims}" ) - logger.warning(msg) - return - if date not in self._dict: - item = next(iter(self._dict.values())) - msg = f"date {date} for item '{item.name}' not found" raise KeyError(msg) - return self._dict[date] + # use the order of self._dims instead of param_keys to ensure reproducibility + key = tuple(parameters[dim] for dim in self._dims) + return self._dict[key] + + def iter_from_cycle_spec(self, spec: ConfigCycleSpec, ref_date: datetime | None = None) -> Iterator[StoreObject]: + # Check date references + if "date" not in self._dims and (spec.lag or spec.date): + msg = f"ParamSeries {self._name} has no date dimension, cannot be referenced by dates" + raise ValueError(msg) + if "date" in self._dims and ref_date is None and spec.date is []: + msg = f"ParamSeries {self._name} has a date dimension, must be referenced by dates" + raise ValueError(msg) + # Generate list of target item keys + keys = [()] + for dim in self._dims: + if dim == "date": + keys = [(*key, date) for key in keys for date in self._resolve_target_dates(spec, ref_date)] + elif dim in spec.parameters: + keys = [(*key, spec.parameters[dim]) for key in keys] + else: + keys = [(*key, item) for key in keys for item in self._axes[dim]] + # Yield items + for key in keys: + yield self._dict[key] - def values(self) -> Iterator[TimeSeriesObject]: + @staticmethod + def _resolve_target_dates(spec: ConfigCycleSpec, ref_date: datetime | None) -> Iterator[datetime]: + if not spec.lag and not spec.date: + yield ref_date + if spec.lag: + for lag in spec.lag: + yield ref_date + lag + if spec.date: + yield from spec.date + + def values(self) -> Iterator[StoreObject]: yield from self._dict.values() -class Store(Generic[TimeSeriesObject]): - """Container for TimeSeries or unique data""" +class Store(Generic[StoreObject]): + """Container for ParamSeries or unique items""" def __init__(self): - self._dict: dict[str, TimeSeries | TimeSeriesObject] = {} + self._dict: dict[str, ParamSeries | StoreObject] = {} - def __setitem__(self, key: str | tuple(str, datetime | None), value: TimeSeriesObject) -> None: + def __setitem__(self, key: str | tuple(str, dict | None), value: StoreObject) -> None: if isinstance(key, tuple): - name, date = key + name, parameters = key + if "date" in parameters and parameters["date"] is None: + del parameters["date"] else: - name, date = key, None + name, parameters = key, None if name in self._dict: - if not isinstance(self._dict[name], TimeSeries): + if not isinstance(self._dict[name], ParamSeries): msg = f"single entry {name} already set" raise KeyError(msg) - if date is None: - msg = f"entry {name} is a TimeSeries, must be accessed by date" + if parameters is None: + msg = f"entry {name} is a ParamSeries, must be accessed by parameters" raise KeyError(msg) - self._dict[name][date] = value - elif date is None: + self._dict[name][parameters] = value + elif not parameters: self._dict[name] = value else: - self._dict[name] = TimeSeries() - self._dict[name][date] = value + self._dict[name] = ParamSeries(name) + self._dict[name][parameters] = value - def __getitem__(self, key: str | tuple(str, datetime | None)) -> TimeSeriesObject: + def __getitem__(self, key: str | tuple(str, dict | None)) -> StoreObject: if isinstance(key, tuple): - name, date = key + name, parameters = key + if "date" in parameters and parameters["date"] is None: + del parameters["date"] else: - name, date = key, None - + name, parameters = key, None if name not in self._dict: msg = f"entry {name} not found in Store" raise KeyError(msg) - if isinstance(self._dict[name], TimeSeries): - if date is None: - msg = f"entry {name} is a TimeSeries, must be accessed by date" + if isinstance(self._dict[name], ParamSeries): + if parameters is None: + msg = f"entry {name} is a ParamSeries, must be accessed by parameters" raise KeyError(msg) - return self._dict[name][date] - if date is not None: - msg = f"entry {name} is not a TimeSeries, cannot be accessed by date" + return self._dict[name][parameters] + if parameters: + msg = f"entry {name} is not a ParamSeries, cannot be accessed by parameters" raise KeyError(msg) return self._dict[name] - @staticmethod - def _resolve_target_dates(spec, ref_date: datetime | None) -> Iterator[datetime]: - if not spec.lag and not spec.date: - yield ref_date - if spec.lag: - for lag in spec.lag: - yield ref_date + lag - if spec.date: - yield from spec.date - - def iter_from_cycle_spec( - self, spec: ConfigCycleSpec, ref_date: datetime | None = None - ) -> Iterator[TimeSeriesObject]: - name = spec.name - if isinstance(self._dict[name], TimeSeries): - if ref_date is None and spec.date is []: - msg = "TimeSeries object must be referenced by dates" + def iter_from_cycle_spec(self, spec: ConfigCycleSpec, ref_date: datetime | None = None) -> Iterator[StoreObject]: + # Check if target items should be querried at all + if (when := spec.when) is not None: + if ref_date is None: + msg = "Cannot use a `when` specification without a `ref_date`" raise ValueError(msg) - for target_date in self._resolve_target_dates(spec, ref_date): - yield self._dict[name][target_date] + if (at := when.at) is not None and at != ref_date: + return + if (before := when.before) is not None and before <= ref_date: + return + if (after := when.after) is not None and after >= ref_date: + return + # Yield items + name = spec.name + if isinstance(self._dict[name], ParamSeries): + yield from self._dict[name].iter_from_cycle_spec(spec, ref_date) else: if spec.lag or spec.date: - msg = f"item {name} is not a TimeSeries, cannot be referenced via date or lag" + msg = f"item {name} is not a ParamSeries, cannot be referenced by date or lag" + raise ValueError(msg) + if spec.parameters: + msg = f"item {name} is not a ParamSeries, cannot be referenced by parameters" raise ValueError(msg) yield self._dict[name] - def values(self) -> Iterator[TimeSeriesObject]: + def values(self) -> Iterator[StoreObject]: for item in self._dict.values(): - if isinstance(item, TimeSeries): + if isinstance(item, ParamSeries): yield from item.values() else: yield item @@ -265,7 +305,7 @@ def __init__(self, workflow_config: ConfigWorkflow) -> None: for data_ref in task_ref.outputs: data_name = data_ref.name data_config = workflow_config.data_dict[data_name] - self.data[data_name, date] = Data.from_config(data_config, date=date) + self.data[data_name, {"date": date}] = Data.from_config(data_config, date=date) # 3 - create cycles and tasks for cycle_config in workflow_config.cycles: @@ -275,11 +315,12 @@ def __init__(self, workflow_config: ConfigWorkflow) -> None: for task_ref in cycle_config.tasks: task_name = task_ref.name task_config = workflow_config.task_dict[task_name] - self.tasks[task_name, date] = ( + task = Task.from_config(task_config, task_ref, workflow=self, date=date) + self.tasks[task_name, {"date": date}] = ( task := Task.from_config(task_config, task_ref, workflow=self, date=date) ) cycle_tasks.append(task) - self.cycles[cycle_name, date] = Cycle(name=cycle_name, tasks=cycle_tasks, date=date) + self.cycles[cycle_name, {"date": date}] = Cycle(name=cycle_name, tasks=cycle_tasks, date=date) # 4 - Link wait on tasks for task in self.tasks.values(): diff --git a/src/sirocco/parsing/_yaml_data_models.py b/src/sirocco/parsing/_yaml_data_models.py index 9bfef60..c4aff99 100644 --- a/src/sirocco/parsing/_yaml_data_models.py +++ b/src/sirocco/parsing/_yaml_data_models.py @@ -44,12 +44,42 @@ def __init__(self, /, **data): super().__init__(**name_and_spec) +class _WhenBaseModel(BaseModel): + """Base class for when specifications""" + + before: datetime | None = None + after: datetime | None = None + at: datetime | None = None + + @model_validator(mode="before") + @classmethod + def check_before_after_at_combination(cls, data: Any) -> Any: + if "at" in data and any(k in data for k in ("before", "after")): + msg = "'at' key is incompatible with 'before' and after'" + raise ValueError(msg) + if not any(k in data for k in ("at", "before", "after")): + msg = "use at least one of 'at', 'before' or 'after' keys" + raise ValueError(msg) + return data + + @field_validator("before", "after", "at", mode="before") + @classmethod + def convert_datetime(cls, value) -> datetime: + if value is None: + return None + return datetime.fromisoformat(value) + + +# TODO: Change class name, does not fit anymore wit hthe addition of `when` and `parameters` +# find something more related to graph specification in general class _LagDateBaseModel(BaseModel): """Base class for all classes containg a list of dates or time lags.""" model_config = ConfigDict(arbitrary_types_allowed=True) date: list[datetime] = [] # this is safe in pydantic lag: list[Duration] = [] # this is safe in pydantic + when: _WhenBaseModel | None = None + parameters: dict | None = None @model_validator(mode="before") @classmethod @@ -75,6 +105,19 @@ def convert_datetimes(cls, value) -> list[datetime]: values = value if isinstance(value, list) else [value] return [datetime.fromisoformat(value) for value in values] + @field_validator("parameters", mode="before") + @classmethod + def check_dict_single_item(cls, params) -> dict: + if params is None: + return None + msg = "parameters must be mappings of a string to a single item" + if not isinstance(params, dict): + raise TypeError(msg) + for k, v in params.items(): + if not isinstance(k, str) or isinstance(v, (list, dict)): + raise TypeError(msg) + return params + class ConfigTask(_NamedBaseModel): """ @@ -153,13 +196,13 @@ class ConfigData(BaseModel): generated: list[ConfigGeneratedData] -class ConfigCycleTaskDepend(_NamedBaseModel, _LagDateBaseModel): +class ConfigCycleTaskWaitOn(_NamedBaseModel, _LagDateBaseModel): """ To create an instance of a input or output in a task in a cycle defined in a workflow file. """ # TODO: Move to "wait_on" keyword in yaml instead of "depend" - name: str # name of the task it depends on + name: str # name of the task it waits on cycle_name: str | None = None @@ -192,7 +235,7 @@ class ConfigCycleTask(_NamedBaseModel): inputs: list[ConfigCycleTaskInput | str] | None = Field(default_factory=list) outputs: list[ConfigCycleTaskOutput | str] | None = Field(default_factory=list) - depends: list[ConfigCycleTaskDepend | str] | None = Field(default_factory=list) + wait_on: list[ConfigCycleTaskWaitOn | str] | None = Field(default_factory=list) @field_validator("inputs", mode="before") @classmethod @@ -220,19 +263,18 @@ def convert_cycle_task_outputs(cls, values) -> list[ConfigCycleTaskOutput]: outputs.append(value) return outputs - @field_validator("depends", mode="before") + @field_validator("wait_on", mode="before") @classmethod - def convert_cycle_task_depends(cls, values) -> list[ConfigCycleTaskDepend]: - depends = [] + def convert_cycle_task_wait_on(cls, values) -> list[ConfigCycleTaskWaitOn]: + wait_on = [] if values is None: - return depends + return wait_on for value in values: if isinstance(value, str): - depends.append({value: None}) + wait_on.append({value: None}) elif isinstance(value, dict): - depends.append(value) - - return depends + wait_on.append(value) + return wait_on class ConfigCycle(_NamedBaseModel): diff --git a/tests/files/configs/test_config_large.yml b/tests/files/configs/test_config_large.yml index 5ff1071..9ce6282 100644 --- a/tests/files/configs/test_config_large.yml +++ b/tests/files/configs/test_config_large.yml @@ -10,20 +10,24 @@ cycles: - icon_bimonthly: start_date: *root_start_date end_date: *root_end_date - period: P2M + period: 'P2M' tasks: - preproc: inputs: [grid_file, extpar_file, ERA5] outputs: [icon_input] - depends: + wait_on: - icon: - lag: -P4M + lag: '-P4M' + when: + after: '2025-03-01T00:00' - icon: inputs: - grid_file - icon_input - icon_restart: lag: '-P2M' + when: + after: *root_start_date outputs: [stream_1, stream_2, icon_restart] - postproc_1: inputs: [stream_1] @@ -34,7 +38,7 @@ cycles: - yearly: start_date: *root_start_date end_date: *root_end_date - period: P1Y + period: 'P1Y' tasks: - postproc_2: inputs: diff --git a/tests/files/configs/test_config_small.yml b/tests/files/configs/test_config_small.yml index 6832ffa..3eccb41 100644 --- a/tests/files/configs/test_config_small.yml +++ b/tests/files/configs/test_config_small.yml @@ -11,11 +11,13 @@ cycles: inputs: - icon_restart: lag: -P2M + when: + after: *root_start_date outputs: [icon_output, icon_restart] - lastly: tasks: - cleanup: - depends: + wait_on: - icon: date: 2026-05-01T00:00 tasks: