From 413c5279e261531d80f597a106f059a310988300 Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Fri, 29 Sep 2023 19:21:58 +0200 Subject: [PATCH] more stuff --- plum/method.py | 103 +++++++++++++++++++++++++++++++++++++--- plum/resolver.py | 45 ++++++++++++++---- plum/signature.py | 31 ++++++++++++ tests/test_signature.py | 2 +- 4 files changed, 165 insertions(+), 16 deletions(-) diff --git a/plum/method.py b/plum/method.py index 5a00839b..d1e79a59 100644 --- a/plum/method.py +++ b/plum/method.py @@ -56,9 +56,9 @@ def __repr__(self): parts = [] if sig.types: for nm, t in zip(argnames, sig.types): - parts.append(f"{nm}: {repr_short(t)}") + parts.append(f"{nm}: {_repr_type(t)}") if sig.varargs != Signature._default_varargs: - parts.append(f"*{argnames[-1]}: {repr_short(sig.varargs)}") + parts.append(f"*{argnames[-1]}: {_repr_type(sig.varargs)}") if len(kwnames) > 0 or kwvar_name is not None: parts.append("*") @@ -69,19 +69,108 @@ def __repr__(self): res = f"{self.function_name}(" + ", ".join(parts) + ")" if self.return_type != Method._default_return_type: - res += f" -> {repr_short(self.return_type)}" + res += f" -> {_repr_type(self.return_type)}" if sig.precedence != Signature._default_precedence: res += "\n\tprecedence=" + repr(sig.precedence) - res += "\n\t" + repr(self.implementation) + res += "\n\t" + self._repr_method_namepath() + + return res + + def _repr_method_namepath(self): + res = repr(self.implementation) try: - res += f" @ {inspect.getfile(self.implementation)}" - res += ":" + str(inspect.getsourcelines(self.implementation)[1]) + fpath = inspect.getfile(self.implementation) + fline = str(inspect.getsourcelines(self.implementation)[1]) + uri = "file://" + fpath + "#" + fline + + import os + + # compress the path + home_path = os.path.expanduser("~") + fpath = fpath.replace(home_path, "~") + + # underline file name + fname = os.path.basename(fpath) + if fname.endswith(".py"): + fpath = fpath.replace( + fname, _colored(_colored(fname, color.BOLD), color.UNDERLINE) + ) + fpath = fpath + ":" + fline + + res += " @ " + _link(uri, fpath) except OSError: - pass + res = "" + return res + + def _repr_signature_mismatch(self, args_ok): + sig = self.signature + + argnames, kwnames, kwvar_name = extract_argnames(self.implementation) + varargs_ok = all(args_ok[len(sig.types) :]) + + parts = [] + if sig.types: + for nm, t, is_ok in zip(argnames, sig.types, args_ok): + clr = (color.RED,) if not is_ok else tuple() + parts.append(f"{nm}: {_repr_type(t, *clr)}") + if sig.varargs != Signature._default_varargs: + clr = (color.RED,) if not varargs_ok else tuple() + parts.append(f"*{argnames[-1]}: {_repr_type(sig.varargs, *clr)}") + + if len(kwnames) > 0 or kwvar_name is not None: + parts.append("*") + for kwnm in kwnames: + parts.append(f"{kwnm}") + if kwvar_name is not None: + parts.append(f"**{kwvar_name}") + + res = f"{self.function_name}(" + ", ".join(parts) + ")" + if self.return_type != Method._default_return_type: + res += f" -> {_repr_type(self.return_type)}" + if sig.precedence != Signature._default_precedence: + res += "\n\tprecedence=" + repr(sig.precedence) + + res += "\n\t" + self._repr_method_namepath() + return res +class color: + PURPLE = "\033[95m" + CYAN = "\033[96m" + DARKCYAN = "\033[36m" + BLUE = "\033[94m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + RED = "\033[91m" + GRAY = "\033[90m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" + END = "\033[0m" + + +def _colored(renderable, *clr): + if not isinstance(renderable, str): + renderable = repr(renderable) + return "".join(clr) + renderable + color.END + + +def _link(uri, label=None): + if label is None: + label = uri + parameters = "" + + # OSC 8 ; params ; URI ST OSC 8 ;; ST + escape_mask = "\033]8;{};{}\033\\{}\033]8;;\033\\" + + return escape_mask.format(parameters, uri, label) + + +def _repr_type(typ, *clrs): + return _colored(repr_short(typ), color.BOLD, *clrs) + + def extract_argnames(f: Callable, precedence: int = 0) -> Signature: """Extract the signature from a function. diff --git a/plum/resolver.py b/plum/resolver.py index 3c1125ee..87eabc90 100644 --- a/plum/resolver.py +++ b/plum/resolver.py @@ -171,7 +171,8 @@ def check(m): if len(candidates) == 0: # There is no matching signature. - raise NotFoundLookupError(f"`{target}` could not be resolved.") + msg = self._resolution_error_hint(target) + raise NotFoundLookupError(f"`{target}` could not be resolved.\n" + msg) elif len(candidates) == 1: # There is exactly one matching signature. Success! return candidates[0] @@ -185,10 +186,38 @@ def check(m): else: # Could not resolve the ambiguity, so error. First, make a nice list # of the candidates and their precedences. - listed_candidates = "\n ".join( - [f"{c} (precedence: {c.precedence})" for c in candidates] - ) - raise AmbiguousLookupError( - f"`{target}` is ambiguous among the following:\n" - f" {listed_candidates}" - ) + msg = f"`msg{target}` is ambiguous among the following:\n" + for i, c in enumerate(candidates): + msg += f"\n [{i}] {c}" + + raise AmbiguousLookupError(msg) + + def _resolution_error_hint( + self, + target: Union[Tuple[object, ...], Signature], + ): + distances = [] + for method in self.methods: + dist = method.signature.compute_distance(target) + distances.append(dist) + + sort_method_ids = _argsort(distances) + + # Take at most 3 hints + sort_method_ids = sort_method_ids[:3] + + distances = [distances[i] for i in sort_method_ids] + methods = [self.methods[i] for i in sort_method_ids] + + # create the error message + fname = self.methods[0].function_name if len(self.methods) > 0 else "???" + msg = f"No method matching {fname}({target})\n\nClosest candidates are:\n" + for m in methods: + args_ok = m.signature.compute_args_ok(target) + msg += m._repr_signature_mismatch(args_ok) + "\n" + + return msg + + +def _argsort(iterable): + return sorted(range(len(iterable)), key=iterable.__getitem__) diff --git a/plum/signature.py b/plum/signature.py index ab91bb11..b19cc5ca 100644 --- a/plum/signature.py +++ b/plum/signature.py @@ -154,6 +154,37 @@ def match(self, values) -> bool: types = self.expand_varargs(len(values)) return all(_is_bearable(v, t) for v, t in zip(values, types)) + def compute_distance(self, values) -> int: + types = self.expand_varargs(len(values)) + # vararg_types = types[len(self.types):] + + distance = 0 + + # count 1 for every extra or missingargument + distance += abs(len(types) - len(values)) + + # count 1 for every mismatching arg type + for v, t in zip(values, types): + if not _is_bearable(v, t): + distance += 1 + + return distance + + def compute_args_ok(self, values) -> list[bool]: + types = self.expand_varargs(len(values)) + + args_ok = [] + + # count 1 for every mismatching arg type + for v, t in zip(values, types): + args_ok.append(_is_bearable(v, t)) + + # all extra args are not ok + for _ in range(len(args_ok), len(values)): + args_ok.append(False) + + return args_ok + def extract_signature(f: Callable, precedence: int = 0) -> Signature: """Extract the signature from a function. diff --git a/tests/test_signature.py b/tests/test_signature.py index b60921d6..c3a6794a 100644 --- a/tests/test_signature.py +++ b/tests/test_signature.py @@ -85,7 +85,7 @@ def _impl(x, y, *z): varargs=complex, precedence=1, ), - f"Signature(int, float, varargs=complex, precedence=1)" + f"Signature(int, float, varargs=complex, precedence=1)", ), ], )