Skip to content

Commit

Permalink
add(parameters): part 1: when kw and generic ParamSeries
Browse files Browse the repository at this point in the history
the `when` keyword was added to address #40. As a result, keeping
track of dates in the TimeSeries class is not necessary anymore. Still a
more generic ParamSeries container needs to be there for future
parametrized tasks, so it was implemented instead of just ditching
TimeSeries and temporarily use regular dictionaries.
  • Loading branch information
leclairm committed Nov 21, 2024
1 parent 903b3db commit 930f660
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 108 deletions.
8 changes: 4 additions & 4 deletions Parser_PoC/parse_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, name, run_spec):
self.run_spec = run_spec
self.input = []
self.output = []
self.depends = []
self.wait_on = []


class WcData():
Expand Down Expand Up @@ -279,12 +279,12 @@ def _add_edges_from_cycle(self, cycle):
if not out_node.is_concrete():
cluster.append(out_node)
# add dependencies
for dep_spec in task_graph_spec.get('depends', []):
for dep_spec in task_graph_spec.get('wait_on', []):
if (dep_node := self.get_task(dep_spec)):
if isinstance(dep_node, list):
task_node.depends.extend(dep_node)
task_node.wait_on.extend(dep_node)
else:
task_node.depends.append(dep_node)
task_node.wait_on.append(dep_node)
self.add_edge(dep_node, task_node)
# Add clsuter
d1 = self.cycling_date
Expand Down
2 changes: 1 addition & 1 deletion Parser_PoC/test_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ scheduling:
- ERA5
output:
- icon input
depends:
wait_on:
- ICON:
lag: '-P4M'
# lag: '-P6M'
Expand Down
215 changes: 128 additions & 87 deletions src/sirocco/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from sirocco.parsing._yaml_data_models import (
ConfigCycleTask,
ConfigCycleTaskDepend,
ConfigCycleTaskInput,
ConfigCycleTaskWaitOn,
ConfigTask,
ConfigWorkflow,
load_workflow_config,
Expand All @@ -19,13 +19,13 @@

from sirocco.parsing._yaml_data_models import ConfigCycle, DataBaseModel

type ConfigCycleSpec = ConfigCycleTaskDepend | ConfigCycleTaskInput
type ConfigCycleSpec = ConfigCycleTaskWaitOn | ConfigCycleTaskInput

logging.basicConfig()
logger = logging.getLogger(__name__)


TimeSeriesObject = TypeVar("TimeSeriesObject")
StoreObject = TypeVar("StoreObject")


class BaseNode:
Expand Down Expand Up @@ -69,7 +69,7 @@ def from_config(
inputs: list[Data] = []
for input_spec in task_ref.inputs:
inputs.extend(data for data in workflow.data.iter_from_cycle_spec(input_spec, date) if data is not None)
outputs: list[Data] = [workflow.data[output_spec.name, date] for output_spec in task_ref.outputs]
outputs: list[Data] = [workflow.data[output_spec.name, {"date": date}] for output_spec in task_ref.outputs]

new = cls(
date=date,
Expand All @@ -80,7 +80,7 @@ def from_config(
) # this works because dataclass has generated this init for us

# Store for actual linking in link_wait_on_tasks() once all tasks are created
new._wait_on_specs = task_ref.depends # noqa: SLF001 we don't have access to self in a dataclass
new._wait_on_specs = task_ref.wait_on # noqa: SLF001 we don't have access to self in a dataclass
# and setting an underscored attribute from
# the class itself raises SLF001

Expand Down Expand Up @@ -126,121 +126,161 @@ class Cycle(BaseNode):
date: datetime | None = None


class TimeSeries(Generic[TimeSeriesObject]):
"""Dictionnary of objects accessed by date, checking start and end dates"""

def __init__(self) -> None:
self.start_date: datetime | None = None
self.end_date: datetime | None = None
self._dict: dict[str:TimeSeriesObject] = {}

def __setitem__(self, date: datetime, data: TimeSeriesObject) -> None:
if date in self._dict:
msg = f"date {date} already used, cannot set twice"
class ParamSeries(Generic[StoreObject]):
"""Dictionnary of objects accessed by arbitrary parameters"""

def __init__(self, name: str) -> None:
self._name = name
self._dims: set | None = None
self._axes: dict | None = None
self._dict: dict[tuple:StoreObject] | None = None

def __setitem__(self, parameters: dict, value: StoreObject) -> None:
# First access: set axes and initialize dictionnary
param_keys = set(parameters.keys())
if self._dims is None:
self._dims = param_keys
self._axes = {k: set() for k in param_keys}
self._dict = {}
# check dimensions
elif self._dims != param_keys:
msg = (
f"ParamSeries {self._name}: parameter keys {param_keys} don't match ParamSeries dimensions {self._dims}"
)
raise KeyError(msg)
self._dict[date] = data
if self.start_date is None:
self.start_date = date
self.end_date = date
elif date < self.start_date:
self.start_date = date
elif date > self.end_date:
self.end_date = date

def __getitem__(self, date: datetime) -> TimeSeriesObject:
if self.start_date is None:
msg = "TimeSeries still empty, cannot access by date"
raise ValueError(msg)
if date < self.start_date or date > self.end_date:
item = next(iter(self._dict.values()))
# Build internal key
# use the order of self._dims instead of param_keys to ensure reproducibility
key = tuple(parameters[dim] for dim in self._dims)
# Check if slot already taken
if key in self._dict:
msg = f"ParamSeries {self._name}: key {key} already used, cannot set item twice"
raise KeyError(msg)
# Store new axes values
for dim in self._dims:
self._axes[dim].add(parameters[dim])
# Set item
self._dict[key] = value

def __getitem__(self, parameters: dict) -> StoreObject:
if self._dims != (param_keys := set(parameters.keys())):
msg = (
f"date {date} for item '{item.name}' is out of bounds [{self.start_date} - {self.end_date}], ignoring."
f"ParamSeries {self._name}: parameter keys {param_keys} don't match ParamSeries dimensions {self._dims}"
)
logger.warning(msg)
return
if date not in self._dict:
item = next(iter(self._dict.values()))
msg = f"date {date} for item '{item.name}' not found"
raise KeyError(msg)
return self._dict[date]
# use the order of self._dims instead of param_keys to ensure reproducibility
key = tuple(parameters[dim] for dim in self._dims)
return self._dict[key]

def iter_from_cycle_spec(self, spec: ConfigCycleSpec, ref_date: datetime | None = None) -> Iterator[StoreObject]:
# Check date references
if "date" not in self._dims and (spec.lag or spec.date):
msg = f"ParamSeries {self._name} has no date dimension, cannot be referenced by dates"
raise ValueError(msg)
if "date" in self._dims and ref_date is None and spec.date is []:
msg = f"ParamSeries {self._name} has a date dimension, must be referenced by dates"
raise ValueError(msg)
# Generate list of target item keys
keys = [()]
for dim in self._dims:
if dim == "date":
keys = [(*key, date) for key in keys for date in self._resolve_target_dates(spec, ref_date)]
elif dim in spec.parameters:
keys = [(*key, spec.parameters[dim]) for key in keys]
else:
keys = [(*key, item) for key in keys for item in self._axes[dim]]
# Yield items
for key in keys:
yield self._dict[key]

def values(self) -> Iterator[TimeSeriesObject]:
@staticmethod
def _resolve_target_dates(spec: ConfigCycleSpec, ref_date: datetime | None) -> Iterator[datetime]:
if not spec.lag and not spec.date:
yield ref_date
if spec.lag:
for lag in spec.lag:
yield ref_date + lag
if spec.date:
yield from spec.date

def values(self) -> Iterator[StoreObject]:
yield from self._dict.values()


class Store(Generic[TimeSeriesObject]):
"""Container for TimeSeries or unique data"""
class Store(Generic[StoreObject]):
"""Container for ParamSeries or unique items"""

def __init__(self):
self._dict: dict[str, TimeSeries | TimeSeriesObject] = {}
self._dict: dict[str, ParamSeries | StoreObject] = {}

def __setitem__(self, key: str | tuple(str, datetime | None), value: TimeSeriesObject) -> None:
def __setitem__(self, key: str | tuple(str, dict | None), value: StoreObject) -> None:
if isinstance(key, tuple):
name, date = key
name, parameters = key
if "date" in parameters and parameters["date"] is None:
del parameters["date"]
else:
name, date = key, None
name, parameters = key, None
if name in self._dict:
if not isinstance(self._dict[name], TimeSeries):
if not isinstance(self._dict[name], ParamSeries):
msg = f"single entry {name} already set"
raise KeyError(msg)
if date is None:
msg = f"entry {name} is a TimeSeries, must be accessed by date"
if parameters is None:
msg = f"entry {name} is a ParamSeries, must be accessed by parameters"
raise KeyError(msg)
self._dict[name][date] = value
elif date is None:
self._dict[name][parameters] = value
elif not parameters:
self._dict[name] = value
else:
self._dict[name] = TimeSeries()
self._dict[name][date] = value
self._dict[name] = ParamSeries(name)
self._dict[name][parameters] = value

def __getitem__(self, key: str | tuple(str, datetime | None)) -> TimeSeriesObject:
def __getitem__(self, key: str | tuple(str, dict | None)) -> StoreObject:
if isinstance(key, tuple):
name, date = key
name, parameters = key
if "date" in parameters and parameters["date"] is None:
del parameters["date"]
else:
name, date = key, None

name, parameters = key, None
if name not in self._dict:
msg = f"entry {name} not found in Store"
raise KeyError(msg)
if isinstance(self._dict[name], TimeSeries):
if date is None:
msg = f"entry {name} is a TimeSeries, must be accessed by date"
if isinstance(self._dict[name], ParamSeries):
if parameters is None:
msg = f"entry {name} is a ParamSeries, must be accessed by parameters"
raise KeyError(msg)
return self._dict[name][date]
if date is not None:
msg = f"entry {name} is not a TimeSeries, cannot be accessed by date"
return self._dict[name][parameters]
if parameters:
msg = f"entry {name} is not a ParamSeries, cannot be accessed by parameters"
raise KeyError(msg)
return self._dict[name]

@staticmethod
def _resolve_target_dates(spec, ref_date: datetime | None) -> Iterator[datetime]:
if not spec.lag and not spec.date:
yield ref_date
if spec.lag:
for lag in spec.lag:
yield ref_date + lag
if spec.date:
yield from spec.date

def iter_from_cycle_spec(
self, spec: ConfigCycleSpec, ref_date: datetime | None = None
) -> Iterator[TimeSeriesObject]:
name = spec.name
if isinstance(self._dict[name], TimeSeries):
if ref_date is None and spec.date is []:
msg = "TimeSeries object must be referenced by dates"
def iter_from_cycle_spec(self, spec: ConfigCycleSpec, ref_date: datetime | None = None) -> Iterator[StoreObject]:
# Check if target items should be querried at all
if (when := spec.when) is not None:
if ref_date is None:
msg = "Cannot use a `when` specification without a `ref_date`"
raise ValueError(msg)
for target_date in self._resolve_target_dates(spec, ref_date):
yield self._dict[name][target_date]
if (at := when.at) is not None and at != ref_date:
return
if (before := when.before) is not None and before <= ref_date:
return
if (after := when.after) is not None and after >= ref_date:
return
# Yield items
name = spec.name
if isinstance(self._dict[name], ParamSeries):
yield from self._dict[name].iter_from_cycle_spec(spec, ref_date)
else:
if spec.lag or spec.date:
msg = f"item {name} is not a TimeSeries, cannot be referenced via date or lag"
msg = f"item {name} is not a ParamSeries, cannot be referenced by date or lag"
raise ValueError(msg)
if spec.parameters:
msg = f"item {name} is not a ParamSeries, cannot be referenced by parameters"
raise ValueError(msg)
yield self._dict[name]

def values(self) -> Iterator[TimeSeriesObject]:
def values(self) -> Iterator[StoreObject]:
for item in self._dict.values():
if isinstance(item, TimeSeries):
if isinstance(item, ParamSeries):
yield from item.values()
else:
yield item
Expand All @@ -265,7 +305,7 @@ def __init__(self, workflow_config: ConfigWorkflow) -> None:
for data_ref in task_ref.outputs:
data_name = data_ref.name
data_config = workflow_config.data_dict[data_name]
self.data[data_name, date] = Data.from_config(data_config, date=date)
self.data[data_name, {"date": date}] = Data.from_config(data_config, date=date)

# 3 - create cycles and tasks
for cycle_config in workflow_config.cycles:
Expand All @@ -275,11 +315,12 @@ def __init__(self, workflow_config: ConfigWorkflow) -> None:
for task_ref in cycle_config.tasks:
task_name = task_ref.name
task_config = workflow_config.task_dict[task_name]
self.tasks[task_name, date] = (
task = Task.from_config(task_config, task_ref, workflow=self, date=date)
self.tasks[task_name, {"date": date}] = (
task := Task.from_config(task_config, task_ref, workflow=self, date=date)
)
cycle_tasks.append(task)
self.cycles[cycle_name, date] = Cycle(name=cycle_name, tasks=cycle_tasks, date=date)
self.cycles[cycle_name, {"date": date}] = Cycle(name=cycle_name, tasks=cycle_tasks, date=date)

# 4 - Link wait on tasks
for task in self.tasks.values():
Expand Down
Loading

0 comments on commit 930f660

Please sign in to comment.