diff --git a/pyproject.toml b/pyproject.toml index 55e91e8..4b1746d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "springs" -version = "1.10.3" +version = "1.11.0" description = """\ A set of utilities to create and manage typed configuration files \ effectively, built on top of OmegaConf.\ diff --git a/src/springs/commandline.py b/src/springs/commandline.py index 7f52567..b23b169 100644 --- a/src/springs/commandline.py +++ b/src/springs/commandline.py @@ -33,11 +33,12 @@ from .logging import configure_logging from .nicknames import NicknameRegistry from .rich_utils import ( + ConfigTreeParser, RichArgumentParser, + TableParser, add_pretty_traceback, - print_config_as_tree, - print_table, ) +from .types_utils import get_type # parameters for the main function MP = ParamSpec("MP") @@ -305,6 +306,10 @@ def wrap_main_method( # setup logging level for the root logger configure_logging(logging_level="DEBUG" if opts.debug else opts.log_level) + # set up parsers for the various config nodes and tables + tree_parser = ConfigTreeParser() + table_parser = TableParser() + # We don't run the main program if the user # has requested to print the any of the config. do_no_run = ( @@ -320,7 +325,7 @@ def wrap_main_method( # relative import here not to mess things up from .resolvers import all_resolvers - print_table( + table_parser( title="Registered Resolvers", columns=["Resolver Name"], values=[(r,) for r in sorted(all_resolvers())], @@ -329,10 +334,11 @@ def wrap_main_method( "For more information, visit https://omegaconf.readthedocs.io/" "en/latest/custom_resolvers.html" ), + borders=True, ) if opts.nicknames: - print_table( + table_parser( title="Registered Nicknames", columns=["Nickname", "Path"], values=NicknameRegistry().all(), @@ -341,14 +347,16 @@ def wrap_main_method( "${sp.ref:nickname,'path.to.key1=value1',...}. " "\nOverride keys are optional (but quotes are required)." ), + borders=True, ) # Print default options if requested py the user if opts.options: - print_config_as_tree( + config_name = getattr(get_type(config_node), "__name__", None) + tree_parser( title="Default Options", + subtitle=f"(class: '{config_name}')" if config_name else None, config=config_node, - title_color="green", print_help=True, ) @@ -364,10 +372,10 @@ def wrap_main_method( # print the configuration if requested by the user if opts.inputs: - print_config_as_tree( - title=f"Input From File {config_file}", + tree_parser( + title="Input From File", + subtitle=f"(path: '{config_file}')", config=file_config, - title_color="green", print_help=False, ) @@ -379,10 +387,9 @@ def wrap_main_method( # print the configuration if requested by the user if opts.inputs: - print_config_as_tree( + tree_parser( title="Input From Command Line", config=cli_config, - title_color="green", print_help=False, ) @@ -402,10 +409,9 @@ def wrap_main_method( # print it if requested if not (opts.quiet) or opts.parsed: - print_config_as_tree( + tree_parser( title="Parsed Config", config=parsed_config, - title_color="green", print_help=False, ) diff --git a/src/springs/rich_utils.py b/src/springs/rich_utils.py index ceee9a8..3836023 100644 --- a/src/springs/rich_utils.py +++ b/src/springs/rich_utils.py @@ -1,9 +1,11 @@ import os import re from argparse import SUPPRESS, ArgumentParser -from typing import IO, Any, Dict, List, Optional, Sequence, Union +from dataclasses import dataclass +from typing import IO, Any, Dict, Generator, List, Optional, Sequence, Union from omegaconf import DictConfig, ListConfig +from rich import box from rich.console import Console, Group from rich.panel import Panel from rich.style import Style @@ -12,162 +14,113 @@ from rich.traceback import install from rich.tree import Tree +from . import MISSING from .core import traverse from .utils import SpringsConfig - -GREY = "grey74" +__all__ = [ + "RichArgumentParser", + "ConfigTreeParser", + "TableParser", + "add_pretty_traceback", +] + + +def _s( + *, + c: Optional[str] = None, + b: Optional[bool] = None, + i: Optional[bool] = None, + u: Optional[bool] = None, + d: Optional[bool] = None, + r: Optional[bool] = None, + l: Optional[bool] = None, # noqa: E741 +) -> Style: + return Style( + color=c, bold=b, italic=i, underline=u, dim=d, conceal=l, reverse=r + ) -def add_pretty_traceback(**install_kwargs: Any) -> None: - if SpringsConfig.RICH_TRACEBACK_INSTALLED: - return - - # override any default settings if provided - install_kwargs = { - **dict(show_locals=SpringsConfig.RICH_LOCALS), - **install_kwargs, - } +@dataclass +class SpringsTheme: + _rich: bool = Console().color_system not in {"standard", "windows"} + _real: bool = Console().is_terminal - # setup nice traceback through rich library - install(**install_kwargs) + # # # # # # # # # # # # # Configuration Trees # # # # # # # # # # # # # # # - # mark as installed; prevent double installation. - # this is a global setting. - SpringsConfig.RICH_TRACEBACK_INSTALLED = True - - -def print_table( - title: str, - columns: Sequence[str], - values: Sequence[Sequence[Any]], - colors: Optional[Sequence[str]] = None, - caption: Optional[str] = None, -): - colors = list( - colors or ["magenta", "cyan", "red", "green", "yellow", "blue"] + r_title: Style = ( + _s(b=True) if _rich else (_s(u=True, d=False) if _real else _s()) + ) + r_help: Style = ( + _s(c="grey74", b=False, i=True) + if _rich + else (_s(u=False, d=True) if _real else _s()) + ) + r_dict: Style = ( + _s(c="magenta", b=False, i=False, d=False, u=False) if _real else _s() + ) + r_list: Style = ( + _s(c="cyan", b=False, i=False, d=False, u=False) if _real else _s() ) - if len(columns) > len(colors): - # repeat colors if we have more columns than colors - colors = colors * (len(columns) // len(colors) + 1) + r_root: Style = ( + _s(c="green", b=False, i=False, d=False, u=False) if _real else _s() + ) + r_leaf: Style = _s(c="default", b=False, i=False, d=False, u=False) - def _get_longest_row(text: str) -> int: - return max(len(row) for row in text.splitlines()) + # # # # # # # # # # # # # # # # Usage Pane # # # # # # # # # # # # # # # # - min_width = min( - max(_get_longest_row(title), _get_longest_row(caption or "")) + 2, - os.get_terminal_size().columns - 2, - ) + u_bold: Style = _s(b=True) if _rich else _s() + u_title: Style = _s(c="default") + u_bold + u_pane: Style = _s(c="cyan") + u_bold + u_exec: Style = _s(c="green", i=False, d=False, u=False) + u_bold + u_path: Style = _s(c="magenta", i=False, d=False, u=False) + u_bold + u_flag: Style = _s(c="yellow", b=False, i=False, d=False, u=False) + u_para: Style = _s(c="default", b=False, i=False, d=False, u=False) + u_plain: Style = r_leaf - table = Table( - *( - Column(column, justify="center", style=color, vertical="middle") - for column, color in zip(columns, colors) - ), - title=f"\n{title}", - min_width=min_width, - caption=caption, - title_style="bold", - caption_style="grey74", - ) - for row in values: - table.add_row(*row) + # # # # # # # # # # # # # # # Tables Design # # # # # # # # # # # # # # # # - Console().print(table) + t_clr: List[str] = MISSING + t_cnt: int = MISSING + t_head: Style = _s(b=True) if _rich else (_s(r=True) if _real else _s()) + t_body: Style = _s(b=False) if _rich else (_s(r=False) if _real else _s()) + # # # # # # # # # # # # # # # # Box Styles # # # # # # # # # # # # # # # # -def print_config_as_tree( - title: str, - config: Union[DictConfig, ListConfig], - title_color: str = "default", - print_help: bool = False, -): - def get_parent_path(path: str) -> str: - return path.rsplit(".", 1)[0] if "." in path else "" + b_show: box.Box = box.ROUNDED + b_hide: box.Box = box.Box("\n".join(" " * 4 for _ in range(8))) - root = Tree(f"[{title_color}][bold]{title}[/bold][/{title_color}]") - trees: Dict[str, Tree] = {"": root} - nodes_order: Dict[str, Dict[str, int]] = {} + def __post_init__(self): + if self.t_clr is MISSING: + self.t_clr = ["magenta", "yellow", "red", "cyan", "green", "blue"] - # STEP 1: We start by adding all nodes to the tree; a node is a - # DictConfig or ListConfig that has children. - all_nodes = sorted( - traverse(config, include_nodes=True, include_leaves=False), - key=lambda spec: spec.path.count("."), - ) - for spec in all_nodes: - parent_path = get_parent_path(spec.path) - tree = trees.get(parent_path, None) - if spec.key is None or tree is None: - raise ValueError("Cannot print disjoined tree") - - # color is different for DictConfig and ListConfig - l_color = "magenta" if isinstance(spec.value, DictConfig) else "cyan" - l_text = spec.key if isinstance(spec.key, str) else f"[{spec.key}]" - label = f"[bold {l_color}]{l_text}[/bold {l_color}]" - - # Add help if available; make it same color as the key, but italic - # instead of bold. Note that we print the help iff print_help is True. - # We also remove any newlines and extra spaces from the help using - # a regex expression. - if spec.help and print_help: - l_help = re.sub(r"\s+", " ", spec.help.strip()) - label = f"{label}\n[{l_color} italic]({l_help})[/italic {l_color}]" - - # Actually adding the node here! - subtree = tree.add(label=label) - - # We need to keep track of each node in the tree separately; this - # is so that we can attach the leaves to the correct node later. - trees[spec.path] = subtree - - # This helps us remember the order nodes appear in the config - # created by the user. We use this to sort the nodes in the tree - # before printing. - nodes_order.setdefault(parent_path, {})[label] = spec.position - - # STEP 2: We now add all leaves to the tree; a leaf is anything that - # is not a DictConfig or ListConfig. - all_leaves = sorted( - traverse(config, include_nodes=False, include_leaves=True), - key=lambda spec: str(spec.key), - ) - for spec in all_leaves: - parent_path = get_parent_path(spec.path) - tree = trees.get(parent_path, None) - if tree is None: - raise ValueError("Cannot find node for this leaf") + if self.t_cnt is MISSING: + self.t_cnt = len(self.t_clr) - # Using '???' to indicate unknown type - type_name = spec.type.__name__ if spec.type else "???" - label = f"[bold]{spec.key}[/bold] ({type_name}) = {spec.value}" + self.t_clr = self.t_clr * (self.t_cnt // len(self.t_clr) + 1) - # Add help if available; print it a gray color and italic. - if spec.help and print_help: - l_help = re.sub(r"\s+", " ", spec.help.strip()) - label = f"{label}\n[{GREY} italic]({l_help})[/{GREY} italic]" + @property + def t_colors(self) -> Generator[Style, None, None]: + for c in self.t_clr: + yield _s(c=c) - # Actually adding the leaf here! - tree.add(label=label) - # This helps us remember the order leaves appear in the config - # created by the user. We use this to sort the nodes in the tree - # before printing. - nodes_order.setdefault(parent_path, {})[label] = spec.position +def add_pretty_traceback(**install_kwargs: Any) -> None: + if SpringsConfig.RICH_TRACEBACK_INSTALLED: + return - # STEP 3: sort nodes in each tree to match the order the appear - # in the config created by the user. - for l, t in trees.items(): # noqa: E741 - t.children.sort(key=lambda child: nodes_order[l][str(child.label)]) + # override any default settings if provided + install_kwargs = { + **dict(show_locals=SpringsConfig.RICH_LOCALS), + **install_kwargs, + } - # STEP 4: if there are no nodes or leaves in this configuration, add a - # message to the tree that indicates that the config is empty. - if len(all_leaves) == len(all_nodes) == 0: - root = Tree(f"{root.label}\n [{GREY} italic](empty)[/{GREY} italic]") + # setup nice traceback through rich library + install(**install_kwargs) - # time to print! - panel = Panel(root, padding=0, border_style=Style(conceal=True)) - Console().print(panel) + # mark as installed; prevent double installation. + # this is a global setting. + SpringsConfig.RICH_TRACEBACK_INSTALLED = True class RichArgumentParser(ArgumentParser): @@ -181,6 +134,7 @@ def __init__( ) -> None: super().__init__(*args, **kwargs) + self.theme = SpringsTheme() self.entrypoint = entrypoint self.arguments = arguments self.formatted: Dict[str, Any] = {} @@ -212,12 +166,13 @@ def format_usage(self): flags.append(flag.strip()) usage = ( - "[green]python[/green] " - + f"[magenta][bold]{self.entrypoint}[/bold][/magenta] " - + "[yellow]" - + " ".join(flags) - + "[/yellow]" - + f" {self.arguments}" + Text(text="python", style=self.theme.u_exec) + + Text(text=" ", style=self.theme.u_plain) + + Text(text=f"{self.entrypoint}", style=self.theme.u_path) + + Text(text=" ", style=self.theme.u_plain) + + Text(text=" ".join(flags), style=self.theme.u_flag) + + Text(text=" ", style=self.theme.u_plain) + + Text(text=f"{self.arguments}", style=self.theme.u_para) ) else: usage = self.usage @@ -225,7 +180,11 @@ def format_usage(self): if usage is not None: return Panel( usage, - title="[bold][cyan] Usage [cyan][/bold]", + title=( + Text(text=" ", style=self.theme.u_plain) + + Text(text="Usage", style=self.theme.u_pane) + + Text(text=" ", style=self.theme.u_plain) + ), title_align="center", ) @@ -234,9 +193,9 @@ def format_help(self): if self.description: description = Panel( - Text(f"{self.description}", justify="center"), - style=Style(bold=True), - border_style=Style(conceal=True), + Text(self.description, justify="center"), + style=self.theme.u_title, + box=box.SIMPLE, ) groups.append(description) @@ -247,26 +206,10 @@ def format_help(self): if len(ag._group_actions) == 0: continue - table = Table( - show_edge=False, - show_header=False, - border_style=Style(bold=False, conceal=True), - ) - table.add_column( - "Flag", style=Style(color="magenta"), justify="left" - ) - table.add_column( - "Default", style=Style(color="yellow"), justify="center" - ) - table.add_column( - "Action", style=Style(color="red"), justify="center" - ) - table.add_column( - "Description", style=Style(color="green"), justify="left" - ) - table.add_row( - *(f"[bold]{c.header}[/bold]" for c in table.columns), - ) + flags = [] + defaults = [] + actions = [] + descriptions = [] for action in ag._group_actions: if action.default == SUPPRESS or action.default is None: @@ -281,34 +224,271 @@ def format_help(self): else: nargs = str(action.nargs) - table.add_row( - "/".join(action.option_strings), - default, - nargs, - (action.help or ""), - ) + flags.append("/".join(action.option_strings)) + defaults.append(default) + actions.append(nargs) + descriptions.append(action.help or "") + table = TableParser.make_table( + columns=["Flag", "Default", "Action", "Description"], + values=list(zip(flags, defaults, actions, descriptions)), + theme=self.theme, + v_justify=["left", "center", "center", "left"], + ) + title = ( + Text(ag.title.capitalize(), style=self.theme.u_pane) + if ag.title + else None + ) panel = Panel( - table, - title=( - Text( - ag.title.capitalize(), - style=Style(bold=True, color="cyan"), - ) - if ag.title - else None - ), - title_align="center", + table, title=title, title_align="center", box=self.theme.b_show ) groups.append(panel) return Panel( Group(*groups), - border_style=Style(conceal=True), + box=self.theme.b_hide, ) def _print_message( self, message: Any, file: Optional[IO[str]] = None ) -> None: - console = Console(**{**self.console_kwargs, "file": file}) - console.print(message) + Console(**{**self.console_kwargs, "file": file}).print(message) + + +class ConfigTreeParser: + def __init__( + self, + theme: Optional[SpringsTheme] = None, + console_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + self.theme = theme or SpringsTheme() + self.console_kwargs = console_kwargs or {} + + @classmethod + def _get_parent_path(cls, path: str) -> str: + return path.rsplit(".", 1)[0] if "." in path else "" + + @classmethod + def make_config_tree( + cls, + title: str, + config: Union[DictConfig, ListConfig], + subtitle: Optional[str] = None, + print_help: bool = False, + theme: Optional[SpringsTheme] = None, + ) -> Tree: + theme = theme or SpringsTheme() + + root_label = Text(text=title, style=theme.r_root + theme.r_title) + if subtitle: + root_label += Text( + text=f"\n{subtitle}", style=theme.r_help + theme.r_root + ) + + root = Tree(label=root_label) + trees: Dict[str, Tree] = {"": root} + nodes_order: Dict[str, Dict[str, int]] = {} + + # STEP 1: We start by adding all nodes to the tree; a node is a + # DictConfig or ListConfig that has children. + all_nodes = sorted( + traverse(config, include_nodes=True, include_leaves=False), + key=lambda spec: spec.path.count("."), + ) + for spec in all_nodes: + parent_path = cls._get_parent_path(spec.path) + tree = trees.get(parent_path, None) + if spec.key is None or tree is None: + raise ValueError("Cannot print disjoined tree") + + # # color is different for DictConfig and ListConfig + style = ( + theme.r_dict + if isinstance(spec.value, DictConfig) + else theme.r_list + ) + text = spec.key if isinstance(spec.key, str) else f"[{spec.key}]" + label = Text(text=text, style=style + theme.r_title) + + # Add help if available; make it same color as the key, but italic + # instead of bold. Note that we print the help iff print_help is + # True. We also remove any newlines and extra spaces from the help + # using a regex expression. + if spec.help and print_help: + label += Text( + text="\n" + re.sub(r"\s+", " ", spec.help.strip()), + style=theme.r_help + style, + ) + + # Actually adding the node here! + subtree = tree.add(label=label) + + # We need to keep track of each node in the tree separately; this + # is so that we can attach the leaves to the correct node later. + trees[spec.path] = subtree + + # This helps us remember the order nodes appear in the config + # created by the user. We use this to sort the nodes in the tree + # before printing. + nodes_order.setdefault(parent_path, {})[str(label)] = spec.position + + # STEP 2: We now add all leaves to the tree; a leaf is anything that + # is not a DictConfig or ListConfig. + all_leaves = sorted( + traverse(config, include_nodes=False, include_leaves=True), + key=lambda spec: str(spec.key), + ) + for spec in all_leaves: + parent_path = cls._get_parent_path(spec.path) + tree = trees.get(parent_path, None) + if tree is None: + raise ValueError("Cannot find node for this leaf") + + # Using '???' to indicate unknown type + type_name = spec.type.__name__ if spec.type else "???" + label = ( + Text(text=str(spec.key), style=theme.r_leaf + theme.r_title) + + Text(text=": ", style=theme.r_leaf) + + Text(text=f"({type_name})", style=theme.r_leaf) + + Text(text=" = ", style=theme.r_leaf) + + Text(text=spec.value, style=theme.r_leaf) + ) + + # Add help if available; print it a gray color and italic. + if spec.help and print_help: + label += Text( + text="\n" + re.sub(r"\s+", " ", spec.help.strip()), + style=theme.r_leaf + theme.r_help, + ) + + # Actually adding the leaf here! + tree.add(label=label) + + # This helps us remember the order leaves appear in the config + # created by the user. We use this to sort the nodes in the tree + # before printing. + nodes_order.setdefault(parent_path, {})[str(label)] = spec.position + + # STEP 3: sort nodes in each tree to match the order the appear + # in the config created by the user. + for leaf, tree in trees.items(): # noqa: E741 + tree.children.sort( + key=lambda child: nodes_order[leaf][str(child.label)] + ) + + # STEP 4: if there are no nodes or leaves in this configuration, add a + # message to the tree that indicates that the config is empty. + if len(all_leaves) == len(all_nodes) == 0: + root_label += Text(text="\n [empty]", style=theme.r_help) + root = Tree(label=root_label) + + return root + + def __call__( + self, + config: Union[DictConfig, ListConfig], + title: Optional[str] = None, + subtitle: Optional[str] = None, + print_help: bool = False, + ): + tree = self.make_config_tree( + title=title or "🌳", + config=config, + subtitle=subtitle, + print_help=print_help, + theme=self.theme, + ) + + # time to print! + panel = Panel(tree, padding=0, box=self.theme.b_hide) + Console(**self.console_kwargs).print(panel) + + +class TableParser: + def __init__( + self, + theme: Optional[SpringsTheme] = None, + console_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + self.theme = theme or SpringsTheme() + self.console_kwargs = console_kwargs or {} + + @classmethod + def make_table( + cls, + columns: Sequence[Any], + values: Sequence[Sequence[Any]], + title: Optional[str] = None, + v_justify: Optional[Sequence[str]] = None, + h_justify: Optional[Sequence[str]] = None, + theme: Optional[SpringsTheme] = None, + caption: Optional[str] = None, + borders: bool = False, + ) -> Table: + theme = theme or SpringsTheme() + v_justify = v_justify or ["center"] * len(columns) + h_justify = h_justify or ["middle"] * len(columns) + + def _get_longest_row(text: str) -> int: + return max(len(row) for row in (text.splitlines() or [""])) + + min_width_outside_content = min( + max(_get_longest_row(title or ""), _get_longest_row(caption or "")) + + 2, + os.get_terminal_size().columns - 2, + ) + + columns = ( + Column( + header=f" {cl} ", + justify=vj, # type: ignore + style=co + theme.t_body, + header_style=co + theme.t_head, + vertical=hj, # type: ignore + ) + for cl, vj, hj, co in zip( + columns, v_justify, h_justify, theme.t_colors + ) + ) + + table = Table( + *columns, + padding=(0, 0), + title=f"\n{title}" if title else None, + min_width=min_width_outside_content, + caption=caption, + title_style=theme.r_title, + caption_style=theme.r_help, + box=(theme.b_show if borders else theme.b_hide), + expand=True, + collapse_padding=True, + ) + for row in values: + table.add_row(*row) + + return table + + def __call__( + self, + columns: Sequence[Any], + values: Sequence[Sequence[Any]], + title: Optional[str] = None, + v_justify: Optional[Sequence[str]] = None, + h_justify: Optional[Sequence[str]] = None, + caption: Optional[str] = None, + borders: bool = False, + ) -> None: + table = self.make_table( + columns=columns, + values=values, + title=title, + v_justify=v_justify, + h_justify=h_justify, + theme=self.theme, + caption=caption, + borders=borders, + ) + # time to print! + panel = Panel(table, padding=0, box=self.theme.b_hide) + Console(**self.console_kwargs).print(panel)