From 49eee4eb7605e1fb2ce03ab664f2a671d02a3f14 Mon Sep 17 00:00:00 2001 From: Hannes Weichelt Date: Tue, 4 Jun 2024 18:05:11 +0200 Subject: [PATCH] fixed linting and typing --- src/clingexplaid/__main__.py | 2 - src/clingexplaid/cli/clingo_app.py | 18 +- src/clingexplaid/cli/textual_gui.py | 267 ++++++++++++------ src/clingexplaid/cli/textual_style.py | 4 + src/clingexplaid/propagators/__init__.py | 2 - .../propagators/propagator_decision_order.py | 176 ------------ .../propagator_solver_decisions.py | 42 ++- 7 files changed, 212 insertions(+), 299 deletions(-) delete mode 100644 src/clingexplaid/propagators/propagator_decision_order.py diff --git a/src/clingexplaid/__main__.py b/src/clingexplaid/__main__.py index 6719b11..7bd17d7 100644 --- a/src/clingexplaid/__main__.py +++ b/src/clingexplaid/__main__.py @@ -4,12 +4,10 @@ import sys -import clingo from clingo.application import clingo_main from .cli.clingo_app import ClingoExplaidApp from .cli.textual_gui import textual_main -from .propagators import SolverDecisionPropagator RUN_TEXTUAL_GUI = False diff --git a/src/clingexplaid/cli/clingo_app.py b/src/clingexplaid/cli/clingo_app.py index 9f6c677..faa48d0 100644 --- a/src/clingexplaid/cli/clingo_app.py +++ b/src/clingexplaid/cli/clingo_app.py @@ -13,7 +13,6 @@ from clingo.application import Application, Flag from ..mus import CoreComputer -from ..propagators import DecisionOrderPropagator from ..transformers import AssumptionTransformer, OptimizationRemover from ..unsat_constraints import UnsatConstraintComputer from ..utils import get_constants_from_arguments @@ -202,7 +201,6 @@ def _method_mus( files=files, assumption_string=mus_string, output_prefix_active=f"{COLORS['RED']}├──{COLORS['NORMAL']}", - output_prefix_passive=f"{COLORS['RED']}│ {COLORS['NORMAL']}", ) # Case: Finding multiple MUS @@ -225,7 +223,6 @@ def _method_mus( files=files, assumption_string=mus_string, output_prefix_active=f"{COLORS['RED']}├──{COLORS['NORMAL']}", - output_prefix_passive=f"{COLORS['RED']}│ {COLORS['NORMAL']}", ) if not n_mus: print( @@ -274,14 +271,7 @@ def _method_unsat_constraints( files: List[str], assumption_string: Optional[str] = None, output_prefix_active: str = "", - output_prefix_passive: str = "", ) -> None: - # register DecisionOrderPropagator if flag is enabled - if self.method_flags["show-decisions"]: - decision_signatures = set(self._show_decisions_decision_signatures.items()) - dop = DecisionOrderPropagator(signatures=decision_signatures, prefix=output_prefix_passive) - control.register_propagator(dop) # type: ignore - ucc = UnsatConstraintComputer(control=control) ucc.parse_files(files) unsat_constraints = ucc.get_unsat_constraints(assumption_string=assumption_string) @@ -309,11 +299,9 @@ def _method_show_decisions( control: clingo.Control, files: List[str], ) -> None: - app = ClingexplaidTextualApp( - files=files, - constants={}, - signatures=set(), - ) + print(control) # only for pylint + + app = ClingexplaidTextualApp(files=files, constants={}) app.run() def print_model(self, model: clingo.Model, _) -> None: # type: ignore diff --git a/src/clingexplaid/cli/textual_gui.py b/src/clingexplaid/cli/textual_gui.py index bdce5cb..1680dea 100644 --- a/src/clingexplaid/cli/textual_gui.py +++ b/src/clingexplaid/cli/textual_gui.py @@ -7,9 +7,10 @@ import itertools import re from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union, cast import clingo +from rich.text import Text from textual import on from textual.app import App, ComposeResult from textual.containers import HorizontalScroll, Vertical, VerticalScroll @@ -28,19 +29,51 @@ TabPane, Tree, ) +from textual.widgets.tree import TreeNode from ..propagators import SolverDecisionPropagator -from ..propagators.propagator_solver_decisions import INTERNAL_STRING +from ..propagators.propagator_solver_decisions import INTERNAL_STRING, Decision from .textual_style import MAIN_CSS ACTIVE_CLASS = "active" +def read_file(path: Union[Path, str]) -> str: + """ + Helper function to get the contents of a file as a string. + """ + file_content = "" + with open(path, "r", encoding="utf-8") as f: + file_content = f.read() + return file_content + + +def flatten_list(ls: Optional[List[List[Any]]]) -> List[Any]: + """ + Helper function to flatten a list + """ + if ls is None: + ls = [] + return list(itertools.chain.from_iterable(ls)) + + +def parse_constants(constant_strings: List[str]) -> Dict[str, str]: + """ + Helper function to parse constants + """ + constants = {} + for const_string in constant_strings: + result = re.search(r"(^[a-zA-Z_][a-zA-Z0-9_]*)=([a-zA-Z_][a-zA-Z0-9_]*|[0-9]+)$", const_string) + if result is not None: + constants[result.group(1)] = result.group(2) + return constants + + class SelectorWidget(Static): """SelectorWidget Field""" - def __init__(self, compose_widgets: List[Widget], update_value_function: Callable) -> None: - super(SelectorWidget, self).__init__() + def __init__(self, compose_widgets: List[Widget], update_value_function: Callable[[Any], str]) -> None: + super().__init__() self.compose_widgets = compose_widgets self.active = True self.value = "" @@ -49,15 +82,24 @@ def __init__(self, compose_widgets: List[Widget], update_value_function: Callabl self.set_active_class() def toggle_active(self) -> None: + """ + Toggles the `SelectorWidget`'s active property. + """ self.active = not self.active if self.active: self.apply_value_function() self.set_active_class() - def apply_value_function(self): + def apply_value_function(self) -> None: + """ + Applies the on __init__ provided `update_value_function` to compute `SelectorWidget.value` + """ self.value = self.update_value_function(self) - def set_active_class(self): + def set_active_class(self) -> None: + """ + Sets the active class of the `SelectorWidget` according to `SelectorWidget.active` + """ if self.active: if ACTIVE_CLASS not in self.classes: self.add_class(ACTIVE_CLASS) @@ -65,12 +107,17 @@ def set_active_class(self): self.remove_class(ACTIVE_CLASS) def compose(self) -> ComposeResult: + """ + Composes the `SelectorWidget`'s components + """ yield Checkbox(value=True) - for element in self.compose_widgets: - yield element + yield from self.compose_widgets @on(Checkbox.Changed) async def selector_changed(self, event: Checkbox.Changed) -> None: + """ + Callback for when the `SelectorWidget`'s Checkbox is changed. + """ # Updating the UI to show the reasons why validation failed if event.checkbox == self.query_one(Checkbox): self.toggle_active() @@ -80,8 +127,8 @@ async def selector_changed(self, event: Checkbox.Changed) -> None: class LabelInputWidget(SelectorWidget): """LabelInputWidget Field""" - def __init__(self, name: str, value: str, update_value_function: Callable) -> None: - super(LabelInputWidget, self).__init__( + def __init__(self, name: str, value: str, update_value_function: Callable[[SelectorWidget], str]) -> None: + super().__init__( compose_widgets=[ Label(name), Input(placeholder="Value", value=value), @@ -93,8 +140,8 @@ def __init__(self, name: str, value: str, update_value_function: Callable) -> No class LabelWidget(SelectorWidget): """LabelWidget Field""" - def __init__(self, path: str, update_value_function: Callable) -> None: - super(LabelWidget, self).__init__( + def __init__(self, path: str, update_value_function: Callable[[SelectorWidget], str]) -> None: + super().__init__( compose_widgets=[ HorizontalScroll(Label(path)), ], @@ -105,64 +152,91 @@ def __init__(self, path: str, update_value_function: Callable) -> None: class SelectorList(Static): """Widget for selecting the program files""" - def __init__(self, selectors: Optional[Iterable], classes: str = "") -> None: - super(SelectorList, self).__init__(classes=classes) + def __init__(self, selectors: Optional[Iterable[Any]], classes: str = "") -> None: + super().__init__(classes=classes) self.add_class("selectors") if selectors is None: selectors = [] self.selectors = selectors def get_selectors(self) -> List[SelectorWidget]: + """ + Base function for getting selectors. This should be overwritten in any classes that inherit from this class. + """ return [] def compose(self) -> ComposeResult: + """ + Composes the `SelectorList`'s components + """ yield VerticalScroll( *self.get_selectors(), ) class ConstantsWidget(SelectorList): + """List Widget for Constants""" def __init__(self, constants: Optional[Dict[str, str]]) -> None: - super(ConstantsWidget, self).__init__(selectors={} if constants is None else constants) + super().__init__(selectors={} if constants is None else constants) - def get_selectors(self) -> List[LabelInputWidget]: + def get_selectors(self) -> List[SelectorWidget]: + """ + Fill the `ConstantsWidget` with `LabelInputWidget`s for each constant. + """ return [ LabelInputWidget(name, value, update_value_function=self.update_value) for name, value in dict(self.selectors).items() ] @staticmethod - def update_value(selector_object): + def update_value(selector_object: SelectorWidget) -> str: + """ + Updates the value for each constant with its name and value + """ label_string = str(selector_object.query_one(Label).renderable).strip() input_string = str(selector_object.query_one(Input).value).strip() return f"#const {label_string}={input_string}." class FilesWidget(SelectorList): + """List Widget for Files""" def __init__(self, files: Optional[List[str]]) -> None: - super(FilesWidget, self).__init__(selectors=[] if files is None else files) + super().__init__(selectors=[] if files is None else files) - def get_selectors(self) -> List[LabelWidget]: + def get_selectors(self) -> List[SelectorWidget]: + """ + Fill the `FilesWidget` with `LabelWidget`s for each file. + """ return [LabelWidget(name, update_value_function=self.update_value) for name in self.selectors] @staticmethod - def update_value(selector_object): + def update_value(selector_object: SelectorWidget) -> str: + """ + Updates the value for each file with its name + """ label_string = str(selector_object.query_one(Label).renderable).strip() return label_string class SignaturesWidget(SelectorList): + """List Widget for Signatures""" def __init__(self, signatures: Optional[List[str]]) -> None: - super(SignaturesWidget, self).__init__(selectors=[] if signatures is None else signatures) + super().__init__(selectors=[] if signatures is None else signatures) - def get_selectors(self) -> List[LabelWidget]: + def get_selectors(self) -> List[SelectorWidget]: + """ + Fill the `SignaturesWidget` with `LabelWidget`s for each signature. + """ return [LabelWidget(name, update_value_function=self.update_value) for name in self.selectors] @staticmethod - def update_value(selector_object): + def update_value(selector_object: SelectorWidget) -> str: + """ + Updates the value for each file with its name and arity + """ label_string = str(selector_object.query_one(Label).renderable).strip() return label_string @@ -174,17 +248,19 @@ def __init__( self, files: List[str], constants: Optional[Dict[str, str]], - signatures: Optional[Set[Tuple[str, int]]], classes: str = "", ) -> None: - super(Sidebar, self).__init__(classes=classes) + super().__init__(classes=classes) self.files = files self.constants = {} if constants is None else constants self.signatures = self.get_all_program_signatures() def get_all_program_signatures(self) -> Set[Tuple[str, int]]: - # TODO: This is done with grounding rn but doing a text processing would probably be more efficient for large - # programs! + """ + Get all signatures occurring in all files provided. + """ + # This is done with grounding rn but doing a text processing would probably be more efficient for large + # programs! ctl = clingo.Control() for file in self.files: ctl.load(file) @@ -192,6 +268,9 @@ def get_all_program_signatures(self) -> Set[Tuple[str, int]]: return {(name, arity) for name, arity, _ in ctl.symbolic_atoms.signatures} def compose(self) -> ComposeResult: + """ + Composes the `Sidebar`'s components + """ with TabbedContent(): with TabPane("Files"): yield FilesWidget(self.files) @@ -206,6 +285,9 @@ class ControlPanel(Static): """Widget for the clingexplaid sidebar""" def compose(self) -> ComposeResult: + """ + Composes the `ControlPanel`'s components + """ yield Label("Mode") yield Select(((line, line) for line in ["SHOW DECISIONS"]), allow_blank=False) yield Label("Models") @@ -223,19 +305,27 @@ def compose(self) -> ComposeResult: @on(Input.Changed) async def input_changed(self, event: Input.Changed) -> None: + """ + Callback for when the `ControlPanel`'s Input is changed. + """ # Updating the UI to show the reasons why validation failed if event.input == self.query_one("#model-number-input"): + if event.validation_result is None: + return if not event.validation_result.is_valid: self.add_class("error") first_error = event.validation_result.failure_descriptions[0] - self.query_one("Label.error").update(first_error) + cast(Label, self.query_one("Label.error")).update(first_error) else: self.remove_class("error") - self.query_one("Label.error").update("") + cast(Label, self.query_one("Label.error")).update("") await self.run_action("update_config") @on(Button.Pressed) async def solve(self, event: Button.Pressed) -> None: + """ + Callback for when the `ControlPanel`'s Button is changed. + """ if event.button == self.query_one("#solve-button"): await self.run_action("solve") @@ -244,48 +334,47 @@ class SolverTreeView(Static): """Widget for the clingexplaid show decisions tree""" def __init__(self, classes: str = "") -> None: - super(SolverTreeView, self).__init__(classes=classes) - self.solve_tree = Tree("Solver Decisions", id="explanation-tree") + super().__init__(classes=classes) + self.solve_tree: Tree[str] = Tree("Solver Decisions", id="explanation-tree") def compose(self) -> ComposeResult: + """ + Composes the `SolverTreeView`'s components + """ self.solve_tree.root.expand() yield self.solve_tree yield LoadingIndicator() -def read_file(path: Union[Path, str]) -> str: - file_content = "" - with open(path, "r", encoding="utf-8") as f: - file_content = f.read() - return file_content - - class ClingexplaidTextualApp(App[int]): """A textual app for a terminal GUI to use the clingexplaid functionality""" + # pylint: disable=too-many-instance-attributes + BINDINGS = [ ("ctrl+x", "exit", "Exit"), - # ("ctrl+s", "solve", "Solve"), ] CSS = MAIN_CSS - def __init__(self, files: List[str], constants: Dict[str, str], signatures: Set[Tuple[str, int]]) -> None: - super(ClingexplaidTextualApp, self).__init__() + def __init__(self, files: List[str], constants: Dict[str, str]) -> None: + super().__init__() self.files = files self.constants = constants - self.signatures = signatures - self.tree_cursor = None + self.tree_cursor: Optional[TreeNode[str]] = None self.model_count = 0 - self.config_model_number = 1 - self.config_show_internal = True - self.loaded_files = set() - self.loaded_signatures = set() + self._config_model_number = 1 + self._config_show_internal = True + self._loaded_files: Set[str] = set() + self._loaded_signatures: Set[Tuple[str, int]] = set() def compose(self) -> ComposeResult: + """ + Composes the `ClingexplaidTextualApp`'s components + """ yield Vertical( ControlPanel(classes="box"), - Sidebar(files=self.files, constants=self.constants, signatures=self.signatures, classes="box tabs"), + Sidebar(files=self.files, constants=self.constants, classes="box tabs"), id="top-cell", ) yield VerticalScroll( @@ -301,29 +390,44 @@ def action_exit(self) -> None: """ self.exit(0) - async def on_model(self, model): + async def on_model(self, model: List[str]) -> None: + """ + Callback for when clingo finds a model. + """ self.model_count += 1 - model = self.tree_cursor.add_leaf(f" MODEL {self.model_count} {' '.join([str(a) for a in model])}") - model.label.stylize("#000000 on #CCCCCC", 0, 7) - model.label.stylize("#000000 on #999999", 7, 7 + 2 + len(str(self.model_count))) + if self.tree_cursor is None: + return + model_node = self.tree_cursor.add_leaf(f" MODEL {self.model_count} {' '.join(model)}") + cast(Text, model_node.label).stylize("#000000 on #CCCCCC", 0, 7) + cast(Text, model_node.label).stylize("#000000 on #999999", 7, 7 + 2 + len(str(self.model_count))) # add some small sleep time to make ux seem more interactive await asyncio.sleep(0.1) - def on_propagate(self, decisions): + def on_propagate(self, decisions: List[Union[Decision, List[Decision]]]) -> None: + """ + Callback for the registered propagator does a propagate step. + """ + if self.tree_cursor is None: + return for element in decisions: if isinstance(element, list): for literal in element: - if literal.matches_any(self.loaded_signatures, show_internal=self.config_show_internal): + if literal.matches_any(self._loaded_signatures, show_internal=self._config_show_internal): entailment = self.tree_cursor.add_leaf(str(literal)).expand() - entailment.label.stylize("#666666") + cast(Text, entailment.label).stylize("#666666") else: new_node = self.tree_cursor.add(str(element)) new_node.expand() self.tree_cursor = new_node - def on_undo(self): + def on_undo(self) -> None: + """ + Callback for the registered propagator does an undo step. + """ + if self.tree_cursor is None: + return undo = self.tree_cursor.add_leaf(f"UNDO {self.tree_cursor.label}") - undo.label.stylize("#E53935") + cast(Text, undo.label).stylize("#E53935") self.tree_cursor = self.tree_cursor.parent async def action_update_config(self) -> None: @@ -331,9 +435,9 @@ async def action_update_config(self) -> None: Action to update the solving config """ # update model number - model_number_input = self.query_one("#model-number-input") + model_number_input = cast(Input, self.query_one("#model-number-input")) model_number = int(model_number_input.value) - self.config_model_number = model_number + self._config_model_number = model_number # update loaded files files_widget = self.query_one(FilesWidget) @@ -342,7 +446,7 @@ async def action_update_config(self) -> None: selector.apply_value_function() if selector.active: files.add(selector.value) - self.loaded_files = files + self._loaded_files = files # update program signatures signatures_widget = self.query_one(SignaturesWidget) @@ -352,14 +456,14 @@ async def action_update_config(self) -> None: if selector.active: signature_strings.add(selector.value) signatures = set() - self.config_show_internal = False + self._config_show_internal = False for signature_string in signature_strings: if signature_string.startswith(INTERNAL_STRING): - self.config_show_internal = True + self._config_show_internal = True else: name, arity = signature_string.split(" / ") signatures.add((name, int(arity))) - self.loaded_signatures = signatures + self._loaded_signatures = signatures async def action_solve(self) -> None: """ @@ -380,11 +484,11 @@ async def action_solve(self) -> None: callback_propagate=self.on_propagate, callback_undo=self.on_undo, ) - ctl = clingo.Control(f"{self.config_model_number}") + ctl = clingo.Control(f"{self._config_model_number}") ctl.register_propagator(sdp) - for file in self.loaded_files: + for file in self._loaded_files: ctl.load(file) - if not self.loaded_files: + if not self._loaded_files: ctl.add("base", [], "") ctl.ground([("base", [])]) @@ -397,7 +501,7 @@ async def action_solve(self) -> None: model = solver_handle.model() if model is None: break - await self.on_model(model.symbols(atoms=True)) + await self.on_model([str(a) for a in model.symbols(atoms=True)]) exhausted = result.exhausted if not exhausted: solver_handle.resume() @@ -409,31 +513,11 @@ async def action_solve(self) -> None: self.tree_cursor.add(end_string) -def flatten_list(ls: Optional[List[List[Any]]]) -> List: - if ls is None: - ls = [] - return list(itertools.chain.from_iterable(ls)) - - -def parse_constants(constant_strings: List[str]) -> Dict[str, str]: - constants = {} - for const_string in constant_strings: - result = re.search(r"(^[a-zA-Z_][a-zA-Z0-9_]*)=([a-zA-Z_][a-zA-Z0-9_]*|[0-9]+)$", const_string) - if result is not None: - constants[result.group(1)] = result.group(2) - return constants - - -def parse_signatures(signature_strings: List[str]) -> Set[Tuple[str, int]]: - signatures = set() - for signature_string in signature_strings: - result = re.search(r"^([a-zA-Z_][a-zA-Z0-9_]*)/([0-9]+)$", signature_string) - if result is not None: - signatures.add((result.group(1), int(result.group(2)))) - return signatures - +def textual_main() -> None: + """ + Main function for the clingo-explaid textual app. This function includes a dedicated ArgumentParser + """ -def textual_main(): parser = argparse.ArgumentParser(prog="clingexplaid", description="What the program does", epilog="Epilog Text") parser.add_argument( "files", @@ -463,6 +547,5 @@ def textual_main(): app = ClingexplaidTextualApp( files=list(set(flatten_list(args.files))), constants=parse_constants(flatten_list(args.const)), - signatures=parse_signatures(flatten_list(args.decision_signature)), ) app.run() diff --git a/src/clingexplaid/cli/textual_style.py b/src/clingexplaid/cli/textual_style.py index 5a0590b..6254be4 100644 --- a/src/clingexplaid/cli/textual_style.py +++ b/src/clingexplaid/cli/textual_style.py @@ -1,3 +1,7 @@ +""" +Module containing TCSS style strings for the textual TUI +""" + MAIN_CSS = """ Screen { layout: grid; diff --git a/src/clingexplaid/propagators/__init__.py b/src/clingexplaid/propagators/__init__.py index b665ff4..44d6273 100644 --- a/src/clingexplaid/propagators/__init__.py +++ b/src/clingexplaid/propagators/__init__.py @@ -6,13 +6,11 @@ from typing import List -from .propagator_decision_order import DecisionOrderPropagator from .propagator_solver_decisions import SolverDecisionPropagator DecisionLevel = List[int] DecisionLevelList = List[DecisionLevel] __all__ = [ - "DecisionOrderPropagator", "SolverDecisionPropagator", ] diff --git a/src/clingexplaid/propagators/propagator_decision_order.py b/src/clingexplaid/propagators/propagator_decision_order.py deleted file mode 100644 index 1e12f48..0000000 --- a/src/clingexplaid/propagators/propagator_decision_order.py +++ /dev/null @@ -1,176 +0,0 @@ -""" -Propagator Module: Decision Order -""" - -from typing import Dict, List, Optional, Sequence, Set, Tuple, Union - -import clingo - -from ..utils.logging import COLORS -from .constants import INDENT_END, INDENT_START, INDENT_STEP, UNKNOWN_SYMBOL_TOKEN - - -class DecisionOrderPropagator: - """ - Propagator for showing the Decision Order of clingo - """ - - def __init__(self, signatures: Optional[Set[Tuple[str, int]]] = None, prefix: str = ""): - # pylint: disable=missing-function-docstring - self.slit_symbol_lookup: Dict[int, clingo.Symbol] = {} - self.signatures = signatures if signatures is not None else set() - self.prefix = prefix - - self.last_decisions: List[int] = [] - self.last_entailments: Dict[int, List[int]] = {} - - def init(self, init: clingo.PropagateInit) -> None: - """ - Method to initialize the Decision Order Propagator. Here the literals are added to the Propagator's watch list. - """ - for atom in init.symbolic_atoms: - program_literal = atom.literal - solver_literal = init.solver_literal(program_literal) - self.slit_symbol_lookup[solver_literal] = atom.symbol - - for atom in init.symbolic_atoms: - if len(self.signatures) > 0 and not any(atom.match(name=s, arity=a) for s, a in self.signatures): - continue - symbolic_atom = init.symbolic_atoms[atom.symbol] - if symbolic_atom is None: - continue # nocoverage - query_program_literal = symbolic_atom.literal - query_solver_literal = init.solver_literal(query_program_literal) - init.add_watch(query_solver_literal) - init.add_watch(-query_solver_literal) - - def _is_printed(self, symbol: Union[clingo.Symbol, str]) -> bool: - """ - Helper function to check if a specific symbol should be printed or not - """ - printed = True - # skip UNKNOWN print if signatures is set - if len(self.signatures) > 0 and symbol == UNKNOWN_SYMBOL_TOKEN: - printed = False # nocoverage - # skip if symbol signature is not in self.signatures - elif len(self.signatures) > 0 and symbol != UNKNOWN_SYMBOL_TOKEN: - # `symbol` can only be a `str` if it is the UNKNOWN_SYMBOL_TOKEN - if isinstance(symbol, str): # nocoverage - printed = False - elif not any(symbol.match(s, a) for s, a in self.signatures): # nocoverage - printed = False - - return printed - - def propagate(self, control: clingo.PropagateControl, changes: Sequence[int]) -> None: - """ - Propagate method the is called when one the registered literals is propagated by clasp. Here useful information - about the decision progress is recorded to be visualized later. - """ - # pylint: disable=unused-argument - decisions, entailments = self.get_decisions(control.assignment) - - print_level = 0 - for d in decisions: - print_level += 1 - if d in self.last_decisions: - continue - - decision_symbol = self.get_symbol(d) - decision_printed = self._is_printed(decision_symbol) - decision_negative = d < 0 - - # build decision indent string - decision_indent_string = INDENT_START + INDENT_STEP * (print_level - 1) - # print decision if it matches the signatures (if provided) - if decision_printed: - print( - f"{self.prefix}{decision_indent_string}" - f"[{['+', '-'][int(decision_negative)]}]" - f" {decision_symbol} " - f"[{d}]" - ) - - entailment_list = entailments[d] if d in entailments else [] - # build entailment indent string - entailment_indent_string = ( - (INDENT_START + INDENT_STEP * (print_level - 2) + INDENT_END) if print_level > 1 else "│ " - ) - for e in entailment_list: - # skip decision in entailments - if e == d: - continue # nocoverage - entailment_symbol = self.get_symbol(e) - entailment_printed = self._is_printed(entailment_symbol) - # skip if entailment symbol doesn't mach signatures (if provided) - if not entailment_printed: - continue # nocoverage - - entailment_negative = e < 0 - if decision_printed: - print( - f"{self.prefix}{entailment_indent_string}{COLORS['GREY']}" - f"[{['+', '-'][int(entailment_negative)]}] " - f"{entailment_symbol} " - f"[{e}]{COLORS['NORMAL']}" - ) - - self.last_decisions = decisions - self.last_entailments = entailments - - def undo(self, thread_id: int, assignment: clingo.Assignment, changes: Sequence[int]) -> None: - """ - This function is called when one of the solvers decisions is undone. - """ - # pylint: disable=unused-argument - - if len(self.last_decisions) < 1: - return # nocoverage - decision = self.last_decisions[-1] - decision_symbol = self.get_symbol(decision) - - # don't print decision undo if its signature is not matching the provided ones - printed = self._is_printed(decision_symbol) - - indent_string = INDENT_START + INDENT_STEP * (len(self.last_decisions) - 1) - if printed: - print(f"{self.prefix}{indent_string}{COLORS['RED']}[✕] {decision_symbol} [{decision}]{COLORS['NORMAL']}") - self.last_decisions = self.last_decisions[:-1] - - @staticmethod - def get_decisions(assignment: clingo.Assignment) -> Tuple[List[int], Dict[int, List[int]]]: - """ - Helper function to extract a list of decisions and entailments from a clingo propagator assignment. - """ - level = 0 - decisions = [] - entailments = {} - try: - while True: - decision = assignment.decision(level) - decisions.append(decision) - - trail = assignment.trail - level_offset_start = trail.begin(level) - level_offset_end = trail.end(level) - level_offset_diff = level_offset_end - level_offset_start - if level_offset_diff > 1: - entailments[decision] = trail[(level_offset_start + 1) : level_offset_end] - level += 1 - except RuntimeError: - return decisions, entailments - - def get_symbol(self, literal: int) -> Union[clingo.Symbol, str]: - """ - Helper function to get a literal's associated symbol. - """ - try: - if literal > 0: - symbol = self.slit_symbol_lookup[literal] - else: - # negate symbol - symbol = clingo.parse_term(str(self.slit_symbol_lookup[-literal])) - except KeyError: - # internal literals - return UNKNOWN_SYMBOL_TOKEN - return symbol diff --git a/src/clingexplaid/propagators/propagator_solver_decisions.py b/src/clingexplaid/propagators/propagator_solver_decisions.py index 389fa85..9641a1b 100644 --- a/src/clingexplaid/propagators/propagator_solver_decisions.py +++ b/src/clingexplaid/propagators/propagator_solver_decisions.py @@ -15,11 +15,19 @@ @dataclass class Decision: + """ + Dataclass representing a solver decision + """ + positive: bool literal: int symbol: Optional[clingo.Symbol] def matches_any(self, signatures: Set[Tuple[str, int]], show_internal: bool = True) -> bool: + """ + Checks if the decisions symbol matches any of the provided `signatures`. If the decisions is an internal + literal `show_internal` is returned. + """ if self.symbol is not None: for sig, arity in signatures: if self.symbol.match(sig, arity): @@ -43,17 +51,19 @@ class SolverDecisionPropagator(Propagator): def __init__( self, signatures: Optional[Set[Tuple[str, int]]] = None, - callback_propagate: Optional[Callable] = None, - callback_undo: Optional[Callable] = None, + callback_propagate: Optional[Callable[[List[Union[Decision, List[Decision]]]], None]] = None, + callback_undo: Optional[Callable[[], None]] = None, ): # pylint: disable=missing-function-docstring self.literal_symbol_lookup: Dict[int, clingo.Symbol] = {} self.signatures = signatures if signatures is not None else set() - self.callback_propagate: Callable = callback_propagate if callback_propagate is not None else lambda x: None - self.callback_undo: Callable = callback_undo if callback_undo is not None else lambda x: None + self.callback_propagate: Callable[[List[Union[Decision, List[Decision]]]], None] = ( + callback_propagate if callback_propagate is not None else lambda x: None + ) + self.callback_undo: Callable[[], None] = callback_undo if callback_undo is not None else lambda: None - self.last_decisions: List[Decision] = [] + self.last_decisions: List[Union[Decision, List[Decision]]] = [] def init(self, init: clingo.PropagateInit) -> None: """ @@ -83,21 +93,22 @@ def propagate(self, control: clingo.PropagateControl, changes: Sequence[int], us # pylint: disable=unused-argument decisions, entailments = self.get_decisions(control.assignment) - literal_sequence = [] + literal_sequence: List[Union[int, List[int]]] = [] for d in decisions: literal_sequence.append(d) if d in entailments: literal_sequence.append(list(entailments[d])) + decision_sequence = self.literal_to_decision_sequence(literal_sequence) if use_diff: - decision_diff = [] - for i in range(len(decision_sequence)): + decision_diff: List[Union[Decision, List[Decision]]] = [] + for i, decision in enumerate(decision_sequence): if i < len(self.last_decisions): - if self.last_decisions[i] != decision_sequence[i]: - decision_diff.append(decision_sequence[i]) + if self.last_decisions[i] != decision: + decision_diff.append(decision) else: - decision_diff.append(decision_sequence[i]) + decision_diff.append(decision) self.last_decisions = decision_sequence self.callback_propagate(decision_diff) else: @@ -112,6 +123,9 @@ def undo(self, thread_id: int, assignment: clingo.Assignment, changes: Sequence[ self.callback_undo() def literal_to_decision(self, literal: int) -> Decision: + """ + Converts a literal integer to a `Decision` object. + """ is_positive = literal >= 0 symbol = self.literal_symbol_lookup.get(abs(literal)) return Decision(literal=abs(literal), positive=is_positive, symbol=symbol) @@ -119,7 +133,11 @@ def literal_to_decision(self, literal: int) -> Decision: def literal_to_decision_sequence( self, literal_sequence: List[Union[int, List[int]]] ) -> List[Union[Decision, List[Decision]]]: - new_decision_sequence = [] + """ + Converts a literal sequence into a decision sequence. These sequences are made up of their respective types or + lists of these types. + """ + new_decision_sequence: List[Union[Decision, List[Decision]]] = [] for element in literal_sequence: if isinstance(element, int): new_decision_sequence.append(self.literal_to_decision(element))