Skip to content

Commit

Permalink
separate pretty printing concerns from graph nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
DropD committed Nov 13, 2024
1 parent 354ebb5 commit 330986d
Show file tree
Hide file tree
Showing 5 changed files with 641 additions and 81 deletions.
74 changes: 11 additions & 63 deletions src/sirocco/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generic, Literal, Self, TypeVar

from termcolor import colored
from typing import TYPE_CHECKING, Generic, Self, TypeVar

from sirocco.parsing._yaml_data_models import (
ConfigCycleTask,
Expand All @@ -30,23 +28,13 @@
TimeSeriesObject = TypeVar("TimeSeriesObject")


class NodeStr:
class BaseNode:
name: str
color: str

def _str_pretty_(self) -> str:
repr_str = colored(self.name, self.color, attrs=["bold"])
if self.date is not None:
repr_str += colored(f" [{self.date}]", self.color)
return repr_str

def __str__(self) -> str:
if self.date is None:
return self.name
return f"{self.name} [{self.date}]"


@dataclass
class Task(NodeStr):
class Task(BaseNode):
name: str
workflow: Workflow
outputs: list[Data] = field(default_factory=list)
Expand All @@ -72,7 +60,11 @@ class Task(NodeStr):
# use classmethod instead of custom init
@classmethod
def from_config(
cls, config: ConfigTask, task_ref: ConfigCycleTask, workflow: Workflow, date: datetime | None = None
cls,
config: ConfigTask,
task_ref: ConfigCycleTask,
workflow: Workflow,
date: datetime | None = None,
) -> Self:
inputs: list[Data] = []
for input_spec in task_ref.inputs:
Expand Down Expand Up @@ -103,7 +95,7 @@ def link_wait_on_tasks(self):


@dataclass(kw_only=True)
class Data(NodeStr):
class Data(BaseNode):
"""Internal representation of a data node"""

color: str = "light_blue"
Expand All @@ -125,7 +117,7 @@ def from_config(cls, config: DataBaseModel, *, date: datetime | None = None):


@dataclass(kw_only=True)
class Cycle(NodeStr):
class Cycle(BaseNode):
"""Internal reprenstation of a cycle"""

color: str = "light_green"
Expand Down Expand Up @@ -299,50 +291,6 @@ def cycle_dates(self, cycle_config: ConfigCycle) -> Iterator[datetime]:
while (date := date + cycle_config.period) < cycle_config.end_date:
yield date

def _str_from_method(self, method_name: Literal["__str__", "_str_pretty_"]) -> str:
str_method = getattr(NodeStr, method_name)
ind = ""
lines = []
lines.append(f"{ind}cycles:")
ind += " "
for cycle in self.cycles.values():
lines.append(f"{ind}- {str_method(cycle)}:")
ind += " "
lines.append(f"{ind}tasks:")
ind += " "
for task in cycle.tasks:
lines.append(f"{ind}- {str_method(task)}:")
ind += " "
if task.inputs:
lines.append(f"{ind}input:")
ind += " "
lines.extend(f"{ind}- {str_method(data)}" for data in task.inputs)
ind = ind[:-2]
if task.outputs:
lines.append(f"{ind}output:")
ind += " "
lines.extend(f"{ind}- {str_method(data)}" for data in task.outputs)
ind = ind[:-2]
if task.wait_on:
lines.append(f"{ind}wait on:")
ind += " "
lines.extend(f"{ind}- {str_method(wait_task)}" for wait_task in task.wait_on)
ind = ind[:-2]
ind = ind[:-4]
ind = ind[:-4]
ind = ind[:-2]
ind = ind[:-2]
return "\n".join(lines)

def __str__(self):
return self._str_from_method("__str__")

def _str_pretty_(self):
return self._str_from_method("_str_pretty_")

def _repr_pretty_(self, p, cycle):
p.text(self._str_pretty_() if not cycle else "...")

@classmethod
def from_yaml(cls, config_path: str):
return cls(load_workflow_config(config_path))
140 changes: 140 additions & 0 deletions src/sirocco/pretty_print.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import dataclasses
import functools
import textwrap
from typing import Any

from termcolor import colored

from . import core


@dataclasses.dataclass
class PrettyPrinter:
    """
    Pretty print unrolled workflow graph elements in a reproducible and human readable format.

    This can be used to compare workflow graphs by their string representation.

    Colored output can be enabled by setting ".colors" to True, this will use terminal control characters,
    which makes it less suited for uses other than human viewing.
    """

    indentation: int = 2  # how many spaces to indent block content by
    colors: bool = False  # True for color output (term control chars)

    def indent(self, string: str) -> str:
        """Indent by the amount set on the instance."""
        return textwrap.indent(string, prefix=" " * self.indentation)

    def as_block(self, header: str, body: str) -> str:
        r"""
        Format as a block with a header line and indented lines of block body text.

        Example:

        >>> print(PrettyPrinter().as_block("header", "foo\nbar"))
        header:
          foo
          bar
        """
        return f"{header}:\n{self.indent(body)}"

    def as_item(self, content: str) -> str:
        r"""
        Format as an item in an unordered list.

        Works for single lines as well as multi line (block) content.
        Empty content yields a bare list marker.

        Example:

        >>> print(PrettyPrinter().as_item("foo"))
        - foo
        >>> pp = PrettyPrinter()
        >>> print(pp.as_item(pp.as_block("header", "multiple\nlines\nof text")))
        - header:
            multiple
            lines
            of text
        """
        marker = "- "
        # Guard on the argument itself; testing the builtin ``str`` is always
        # truthy and would let empty content fall through to an IndexError.
        if not content:
            return marker
        header, *rest = content.splitlines()
        if not rest:
            return f"{marker}{content}"
        # Continuation lines are shifted by the marker width so that they line
        # up with the text of the first line rather than with the dash.
        body = textwrap.indent("\n".join(rest), prefix=" " * len(marker))
        return f"{marker}{header}\n{body}"

    @functools.singledispatchmethod
    def format(self, obj: Any) -> str:
        """
        Dispatch formatting based on node type.

        Default implementation simply calls str()
        """
        return str(obj)

    @format.register
    def format_basic(self, obj: core.BaseNode) -> str:
        """
        Default formatting for BaseNode.

        Can also be used explicitly to get a single line representation of any node:
        the node name, followed by its date in square brackets when a date is set.
        With ".colors" enabled both parts are wrapped in the node's color (name in bold).

        Example:

        >>> from datetime import datetime
        >>> print(
        ...     PrettyPrinter().format_basic(
        ...         core.Task(name="foo", date=datetime(1000, 1, 1).date(), workflow=None)
        ...     )
        ... )
        foo [1000-01-01]
        """
        name = obj.name
        date = f"[{obj.date}]" if obj.date else None
        if self.colors:
            name = colored(name, obj.color, attrs=["bold"])
            date = colored(date, obj.color) if date else None
        return f"{name} {date}" if date else name

    @format.register
    def format_workflow(self, obj: core.Workflow) -> str:
        """Format a workflow as a "cycles" block containing every formatted cycle."""
        cycles = "\n".join(self.format(cycle) for cycle in obj.cycles.values())
        return self.as_block("cycles", cycles)

    @format.register
    def format_cycle(self, obj: core.Cycle) -> str:
        """Format a cycle as a list item whose body is a "tasks" block of its formatted tasks."""
        tasks = self.as_block("tasks", "\n".join(self.format(task) for task in obj.tasks))
        return self.as_item(self.as_block(self.format_basic(obj), tasks))

    @format.register
    def format_task(self, obj: core.Task) -> str:
        """
        Format a task as a list item with "input", "output" and "wait on" sections.

        Sections appear in that fixed order; a section is omitted when the task
        has no nodes of that kind. Each node is rendered on a single line via
        ``format_basic``.
        """
        labelled_nodes = (
            ("input", obj.inputs),
            ("output", obj.outputs),
            ("wait on", obj.wait_on),
        )
        sections = [
            self.as_block(
                label,
                "\n".join(self.as_item(self.format_basic(node)) for node in nodes),
            )
            for label, nodes in labelled_nodes
            if nodes
        ]
        return self.as_item(
            self.as_block(
                self.format_basic(obj),
                "\n".join(sections),
            )
        )
Loading

0 comments on commit 330986d

Please sign in to comment.