Skip to content

Commit

Permalink
ref: use dataclass for Task as well
Browse files Browse the repository at this point in the history
replace custom __init__ by from_config classmethod
  • Loading branch information
leclairm committed Nov 7, 2024
1 parent 69a326d commit ce6fa8e
Showing 1 changed file with 29 additions and 32 deletions.
61 changes: 29 additions & 32 deletions src/sirocco/core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generic, Literal, TypeVar

from termcolor import colored
Expand Down Expand Up @@ -43,16 +43,17 @@ def __str__(self) -> str:
return f"{self.name} [{self.date}]"


@dataclass
class Task(NodeStr):
"""Internal representation of a task node"""

color: str = "light_red"

name: str
outputs: list[Data]
inputs: list[Data]
wait_on: list[Task]
workflow: Workflow
outputs: list[Data] = field(default_factory=list)
inputs: list[Data] = field(default_factory=list)
wait_on: list[Task] = field(default_factory=list)
date: datetime | None = None
color: str = "light_red"
# TODO: This list is too long. We should start with the set of supported
# keywords and extend it as we support more
command: str | None = None
Expand All @@ -68,35 +69,31 @@ class Task(NodeStr):
src: str | None = None
conda_env: str | None = None

def __init__(self, config: ConfigTask, task_ref: ConfigCycleTask, workflow: Workflow, date: datetime | None = None):
self.name = config.name
self.date = date
self.inputs = []
self.outputs = []
self.wait_on = []
self.workflow = workflow
# Long list of not always supported keywords
self.command = config.command
self.command_option = config.command_option
self.input_arg_options = config.input_arg_options
self.host = config.host
self.account = config.account
self.plugin = config.plugin
self.config = config.config
self.uenv = config.uenv
self.nodes = config.nodes
self.walltime = config.walltime
self.src = config.src
self.conda_env = config.conda_env

# use classmethod instead of custom init
@classmethod
def from_config(cls, config: ConfigTask, task_ref: ConfigCycleTask, workflow: Workflow, date: datetime | None = None) -> Self:
inputs: list[Data] = []
for input_spec in task_ref.inputs:
for data in workflow.data.get(input_spec, self.date):
for data in workflow.data.get(input_spec, date):
if data is not None:
self.inputs.append(data)
inputs.append(data)

outputs: list[Data] = []
for output_spec in task_ref.outputs:
self.outputs.append(self.workflow.data[output_spec.name, self.date])
outputs.append(workflow.data[output_spec.name, date])

new = cls(
date=date,
inputs=inputs,
outputs=outputs,
workflow=workflow,
**dict(config) # use the fact that pydantic models can be turned into dicts easily
) # this works because dataclass has generated this init for us

# Store for actual linking in link_wait_on_tasks() once all tasks are created
self._wait_on_specs = task_ref.depends
new._wait_on_specs = task_ref.depends

return new

def link_wait_on_tasks(self):
for wait_on_spec in self._wait_on_specs:
Expand Down Expand Up @@ -274,7 +271,7 @@ def __init__(self, workflow_config: ConfigWorkflow) -> None:
for task_ref in cycle_config.tasks:
task_name = task_ref.name
task_config = workflow_config.task_dict[task_name]
self.tasks[task_name, date] = (task := Task(task_config, task_ref, workflow=self, date=date))
self.tasks[task_name, date] = (task := Task.from_config(task_config, task_ref, workflow=self, date=date))
cycle_tasks.append(task)
self.cycles[cycle_name, date] = Cycle(name=cycle_name, tasks=cycle_tasks, date=date)

Expand Down

0 comments on commit ce6fa8e

Please sign in to comment.