Skip to content

Commit

Permalink
Fuse classes for wait_on tasks and input data (#56)
Browse files Browse the repository at this point in the history
`ConfigCycleTaskInput` and `ConfigCycleTaskWaitOn` classes were gathering the same required information for targeting other nodes. No need to have 2.
  • Loading branch information
leclairm authored Dec 2, 2024
1 parent 897faca commit 4d4e501
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 111 deletions.
12 changes: 4 additions & 8 deletions src/sirocco/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

from sirocco.parsing._yaml_data_models import (
ConfigCycleTask,
ConfigCycleTaskInput,
ConfigCycleTaskWaitOn,
ConfigTask,
ConfigWorkflow,
load_workflow_config,
Expand All @@ -18,9 +16,7 @@
from collections.abc import Iterator
from datetime import datetime

from sirocco.parsing._yaml_data_models import ConfigCycle, DataBaseModel

type ConfigCycleSpec = ConfigCycleTaskWaitOn | ConfigCycleTaskInput
from sirocco.parsing._yaml_data_models import ConfigCycle, DataBaseModel, TargetNodesBaseModel

logging.basicConfig()
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -185,7 +181,7 @@ def __getitem__(self, coordinates: dict) -> GraphItem:
key = tuple(coordinates[dim] for dim in self._dims)
return self._dict[key]

def iter_from_cycle_spec(self, spec: ConfigCycleSpec, reference: dict) -> Iterator[GraphItem]:
def iter_from_cycle_spec(self, spec: TargetNodesBaseModel, reference: dict) -> Iterator[GraphItem]:
# Check date references
if "date" not in self._dims and (spec.lag or spec.date):
msg = f"Array {self._name} has no date dimension, cannot be referenced by dates"
Expand All @@ -197,7 +193,7 @@ def iter_from_cycle_spec(self, spec: ConfigCycleSpec, reference: dict) -> Iterat
for key in product(*(self._resolve_target_dim(spec, dim, reference) for dim in self._dims)):
yield self._dict[key]

def _resolve_target_dim(self, spec: ConfigCycleSpec, dim: str, reference: Any) -> Iterator[Any]:
def _resolve_target_dim(self, spec: TargetNodesBaseModel, dim: str, reference: Any) -> Iterator[Any]:
if dim == "date":
if not spec.lag and not spec.date:
yield reference["date"]
Expand Down Expand Up @@ -239,7 +235,7 @@ def __getitem__(self, key: tuple[str, dict]) -> GraphItem:
raise KeyError(msg)
return self._dict[name][coordinates]

def iter_from_cycle_spec(self, spec: ConfigCycleSpec, reference: dict) -> Iterator[GraphItem]:
def iter_from_cycle_spec(self, spec: TargetNodesBaseModel, reference: dict) -> Iterator[GraphItem]:
# Check if target items should be querried at all
if (when := spec.when) is not None:
if (ref_date := reference.get("date")) is None:
Expand Down
191 changes: 88 additions & 103 deletions src/sirocco/parsing/_yaml_data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,13 @@ def convert_datetime(cls, value) -> datetime:
return datetime.fromisoformat(value)


# TODO: Change class name, does not fit anymore wit hthe addition of `when` and `parameters`
# find something more related to graph specification in general like _GraphTargetBaseModel
class _LagDateBaseModel(BaseModel):
"""Base class for all classes containg a list of dates or time lags."""
class TargetNodesBaseModel(_NamedBaseModel):
"""class for targeting other task or data nodes in the graph
When specifying cycle tasks, this class gathers the required information for
targeting other nodes, either input data or wait on tasks.
"""

model_config = ConfigDict(arbitrary_types_allowed=True)
date: list[datetime] = [] # this is safe in pydantic
Expand Down Expand Up @@ -117,111 +120,14 @@ def check_dict_single_item(cls, params: dict) -> dict:
return params


class ConfigTask(_NamedBaseModel):
"""
To create an instance of a task defined in a workflow file
"""

# TODO: This list is too large. We should start with the set of supported
# keywords and extend it as we support more
command: str
command_option: str | None = None
input_arg_options: dict[str, str] | None = None
parameters: list[str] = []
host: str | None = None
account: str | None = None
plugin: str | None = None
config: str | None = None
uenv: dict | None = None
nodes: int | None = None
walltime: str | None = None
src: str | None = None
conda_env: str | None = None

def __init__(self, /, **data):
# We have to treat root special as it does not typically define a command
if "ROOT" in data and "command" not in data["ROOT"]:
data["ROOT"]["command"] = "ROOT_PLACEHOLDER"
super().__init__(**data)

@field_validator("command")
@classmethod
def expand_env_vars(cls, value: str) -> str:
"""Expands any environment variables in the value"""
return expandvars(value)

@field_validator("walltime")
@classmethod
def convert_to_struct_time(cls, value: str | None) -> time.struct_time | None:
"""Converts a string of form "%H:%M:%S" to a time.time_struct"""
return None if value is None else time.strptime(value, "%H:%M:%S")


class DataBaseModel(_NamedBaseModel):
"""
To create an instance of a data defined in a workflow file.
"""

type: str
src: str
format: str | None = None
parameters: list[str] = []

@field_validator("type")
@classmethod
def is_file_or_dir(cls, value: str) -> str:
"""."""
if value not in ["file", "dir"]:
msg = "Must be one of 'file' or 'dir'."
raise ValueError(msg)
return value

@property
def available(self) -> bool:
return isinstance(self, ConfigAvailableData)


class ConfigAvailableData(DataBaseModel):
class ConfigCycleTaskInput(TargetNodesBaseModel):
pass


class ConfigGeneratedData(DataBaseModel):
class ConfigCycleTaskWaitOn(TargetNodesBaseModel):
pass


class ConfigData(BaseModel):
"""To create the container of available and generated data"""

available: list[ConfigAvailableData] = []
generated: list[ConfigGeneratedData] = []


class ConfigCycleTaskWaitOn(_NamedBaseModel, _LagDateBaseModel):
"""
To create an instance of a input or output in a task in a cycle defined in a workflow file.
"""

# TODO: Move to "wait_on" keyword in yaml instead of "depend"
name: str # name of the task it waits on
cycle_name: str | None = None


class ConfigCycleTaskInput(_NamedBaseModel, _LagDateBaseModel):
"""
To create an instance of an input in a task in a cycle defined in a workflow file.
For example:
.. yaml
- my_input:
date: ...
lag: ...
"""

arg_option: str | None = None


class ConfigCycleTaskOutput(_NamedBaseModel):
"""
To create an instance of an output in a task in a cycle defined in a workflow file.
Expand Down Expand Up @@ -324,6 +230,85 @@ def check_period_is_not_negative_or_zero(self) -> ConfigCycle:
return self


class ConfigTask(_NamedBaseModel):
"""
To create an instance of a task defined in a workflow file
"""

# TODO: This list is too large. We should start with the set of supported
# keywords and extend it as we support more
command: str
command_option: str | None = None
input_arg_options: dict[str, str] | None = None
parameters: list[str] = []
host: str | None = None
account: str | None = None
plugin: str | None = None
config: str | None = None
uenv: dict | None = None
nodes: int | None = None
walltime: str | None = None
src: str | None = None
conda_env: str | None = None

def __init__(self, /, **data):
# We have to treat root special as it does not typically define a command
if "ROOT" in data and "command" not in data["ROOT"]:
data["ROOT"]["command"] = "ROOT_PLACEHOLDER"
super().__init__(**data)

@field_validator("command")
@classmethod
def expand_env_vars(cls, value: str) -> str:
"""Expands any environment variables in the value"""
return expandvars(value)

@field_validator("walltime")
@classmethod
def convert_to_struct_time(cls, value: str | None) -> time.struct_time | None:
"""Converts a string of form "%H:%M:%S" to a time.time_struct"""
return None if value is None else time.strptime(value, "%H:%M:%S")


class DataBaseModel(_NamedBaseModel):
"""
To create an instance of a data defined in a workflow file.
"""

type: str
src: str
format: str | None = None
parameters: list[str] = []

@field_validator("type")
@classmethod
def is_file_or_dir(cls, value: str) -> str:
"""."""
if value not in ["file", "dir"]:
msg = "Must be one of 'file' or 'dir'."
raise ValueError(msg)
return value

@property
def available(self) -> bool:
return isinstance(self, ConfigAvailableData)


class ConfigAvailableData(DataBaseModel):
pass


class ConfigGeneratedData(DataBaseModel):
pass


class ConfigData(BaseModel):
"""To create the container of available and generated data"""

available: list[ConfigAvailableData] = []
generated: list[ConfigGeneratedData] = []


class ConfigWorkflow(BaseModel):
name: str | None = None
cycles: list[ConfigCycle]
Expand Down

0 comments on commit 4d4e501

Please sign in to comment.