diff --git a/dol/signatures.py b/dol/signatures.py index 2e5b7271..24f7f9d7 100644 --- a/dol/signatures.py +++ b/dol/signatures.py @@ -95,6 +95,9 @@ Iterator, TypeVar, Mapping as MappingType, + Literal, + Optional, + get_args, ) from typing import KT, VT, T from types import FunctionType @@ -154,6 +157,8 @@ def wrapper(*args, **kwargs): var_param_kind_dflts_items = tuple({VP: (), VK: {}}.items()) DFLT_DEFAULT_CONFLICT_METHOD = 'strict' +SigMergeOptions = Literal[None, 'strict', 'take_first', 'fill_defaults_and_annotations'] + param_attributes = {'name', 'kind', 'default', 'annotation'} @@ -180,6 +185,11 @@ class IncompatibleSignatures(ValueError): # raise AttributeError(f"module {__name__} has no attribute {name}") +def is_signature_error(e: BaseException) -> bool: + """Check if an exception is a signature error""" + return isinstance(e, ValueError) and 'no signature found' in str(e) + + def _param_sort_key(param): return (param.kind, param.kind == KO or param.default is not empty) @@ -705,7 +715,10 @@ def extract_arguments( if include_all_when_var_keywords_in_params: if ( - next((p.name for p in params if p.kind == Parameter.VAR_KEYWORD), None,) + next( + (p.name for p in params if p.kind == Parameter.VAR_KEYWORD), + None, + ) is not None ): param_kwargs.update(remaining_kwargs) @@ -1975,7 +1988,7 @@ def merge_with_sig( sig: ParamsAble, ch_to_all_pk: bool = False, *, - default_conflict_method: str = DFLT_DEFAULT_CONFLICT_METHOD, + default_conflict_method: SigMergeOptions = DFLT_DEFAULT_CONFLICT_METHOD, ): """Return a signature obtained by merging self signature with another signature. Insofar as it can, given the kind precedence rules, the arguments of self will @@ -2043,14 +2056,8 @@ def merge_with_sig( 'them with a signature mapping that avoids the argument name clashing' ) - assert default_conflict_method in { - None, - 'strict', - 'take_first', - 'fill_defaults_and_annotations', - }, ( - 'default_conflict_method should be in ' - "{None, 'strict', 'take_first', 'fill_defaults_and_annotations'}" + assert default_conflict_method in get_args(SigMergeOptions), ( + 'default_conflict_method should be one of: ' f"{get_args(SigMergeOptions)}" ) if default_conflict_method == 'take_first': @@ -2925,7 +2932,9 @@ def extract_args_and_kwargs( ignore_kind=_ignore_kind, ) return self.mk_args_and_kwargs( - arguments, allow_partial=_allow_partial, args_limit=_args_limit, + arguments, + allow_partial=_allow_partial, + args_limit=_args_limit, ) def source_arguments( @@ -3090,7 +3099,9 @@ def source_args_and_kwargs( **kwargs, ) return self.mk_args_and_kwargs( - arguments, allow_partial=_allow_partial, args_limit=_args_limit, + arguments, + allow_partial=_allow_partial, + args_limit=_args_limit, ) @@ -4227,32 +4238,23 @@ def zip(*iterables): zip(*iterables) --> A zip object yielding tuples until an input is exhausted. """ - def bool(x: Any, /) -> bool: - ... + def bool(x: Any, /) -> bool: ... - def bytearray(iterable_of_ints: Iterable[int], /): - ... + def bytearray(iterable_of_ints: Iterable[int], /): ... - def classmethod(function: Callable, /): - ... + def classmethod(function: Callable, /): ... - def int(x, base=10, /): - ... + def int(x, base=10, /): ... - def iter(callable: Callable, sentinel=None, /): - ... + def iter(callable: Callable, sentinel=None, /): ... - def next(iterator: Iterator, default=None, /): - ... + def next(iterator: Iterator, default=None, /): ... - def staticmethod(function: Callable, /): - ... + def staticmethod(function: Callable, /): ... - def str(bytes_or_buffer, encoding=None, errors=None, /): - ... + def str(bytes_or_buffer, encoding=None, errors=None, /): ... - def super(type_, obj=None, /): - ... + def super(type_, obj=None, /): ... # def type(name, bases=None, dict=None, /): # ... @@ -4422,14 +4424,11 @@ class sigs_for_type_name: signatures (through ``inspect.signature``), """ - def itemgetter(iterable: Iterable[VT], /) -> Union[VT, Tuple[VT]]: - ... + def itemgetter(iterable: Iterable[VT], /) -> Union[VT, Tuple[VT]]: ... - def attrgetter(iterable: Iterable[VT], /) -> Union[VT, Tuple[VT]]: - ... + def attrgetter(iterable: Iterable[VT], /) -> Union[VT, Tuple[VT]]: ... - def methodcaller(obj: Any) -> Any: - ... + def methodcaller(obj: Any) -> Any: ... ############# Tools for testing ######################################################### @@ -4474,7 +4473,9 @@ def param_for_kind( lower_kind = kind.lower() setattr(param_for_kind, lower_kind, partial(param_for_kind, kind=kind)) setattr( - param_for_kind, 'with_default', partial(param_for_kind, with_default=True), + param_for_kind, + 'with_default', + partial(param_for_kind, with_default=True), ) setattr( getattr(param_for_kind, lower_kind), @@ -4488,7 +4489,7 @@ def param_for_kind( ) ######################################################################################## -# Signature Compatibility # +# Signature Comparison and Compatibility # ######################################################################################## Compared = TypeVar('Compared') @@ -4528,7 +4529,10 @@ def mk_func_comparator_based_on_signature_comparator( def _keyed_comparator( - comparator: Comparator, key: KeyFunction, x: CT, y: CT, + comparator: Comparator, + key: KeyFunction, + x: CT, + y: CT, ) -> Comparison: """Apply a comparator after transforming inputs through a key function. @@ -4542,7 +4546,10 @@ def _keyed_comparator( return comparator(key(x), key(y)) -def keyed_comparator(comparator: Comparator, key: KeyFunction,) -> Comparator: +def keyed_comparator( + comparator: Comparator, + key: KeyFunction, +) -> Comparator: """Create a key-function enabled binary operator. In various places in python functionality is extended by allowing a key function. @@ -4625,59 +4632,135 @@ def param_comparator( ) +param_comparator: ParamComparator param_binary_func = param_comparator # back compatibility alias # TODO: Implement annotation compatibility -def is_annotation_compatible_with(annot1, annot2): +def ignore_any_differences(x, y): return True -def is_default_value_compatible_with(dflt1, dflt2): - return dflt1 is empty or dflt2 is not empty +permissive_param_comparator = partial( + param_comparator, + name=ignore_any_differences, + kind=ignore_any_differences, + default=ignore_any_differences, + annotation=ignore_any_differences, +) +permissive_param_comparator.__doc__ = """ +Permissive version of param_comparator that ignores any differences of parameter +attributes. +It is meant to be used with partial, but with a permissive base, contrary to the +base param_comparator which requires strict equality (`eq`) for all attributes. +""" -def is_param_compatible_with( - p1: Parameter, - p2: Parameter, - annotation_comparator: Comparator = None, - default_value_comparator: Comparator = None, -): - """Return True if ``p1`` is compatible with ``p2``. Meaning that any value valid - for ``p1`` is valid for ``p2``. - :param p1: The main parameter. - :param p2: The parameter to be compared with. - :param annotation_comparator: The function used to compare the annotations - :param default_value_comparator: The function used to compare the default values +def return_tuple(x, y): + return x, y - >>> is_param_compatible_with( - ... Parameter('a', PO), - ... Parameter('b', PO) - ... ) - True - >>> is_param_compatible_with( - ... Parameter('a', PO), - ... Parameter('b', PO, default=0) - ... ) - True - >>> is_param_compatible_with( - ... Parameter('a', PO, default=0), - ... Parameter('b', PO) - ... ) - False + +param_attribute_dict: ComparisonAggreg + + +def param_attribute_dict(name_kind_default_annotation: Iterable[Comparison]) -> dict: + keys = ['name', 'kind', 'default', 'annotation'] + return {key: value for key, value in zip(keys, name_kind_default_annotation)} + + +param_comparison_dict = partial( + param_comparator, + name=return_tuple, + kind=return_tuple, + default=return_tuple, + annotation=return_tuple, + aggreg=param_attribute_dict, +) + +param_comparison_dict.__doc__ = """ +A ParamComparator that returns a dictionary with pairs parameter attributes. + +>>> param1 = Sig('(a: int = 1)')['a'] +>>> param2 = Sig('(a: str = 2)')['a'] +>>> param_comparison_dict(param1, param2) # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS +{'name': ('a', 'a'), 'kind': ..., 'default': (1, 2), 'annotation': (, )} +""" + + +def param_differences_dict( + param1: Parameter, + param2: Parameter, + *, + name: Comparator = eq, + kind: Comparator = eq, + default: Comparator = eq, + annotation: Comparator = eq, +): + """Makes a dictionary exibiting the differences between two parameters. + + >>> param1 = Sig('(a: int = 1)')['a'] + >>> param2 = Sig('(a: str = 2)')['a'] + >>> param_differences_dict(param1, param2) + {'default': (1, 2), 'annotation': (, )} + >>> param_differences_dict(param1, param2, default=lambda x, y: isinstance(x, type(y))) + {'annotation': (, )} """ - # TODO: Consider using functions as defaults instead of None - annotation_comparator = annotation_comparator or is_annotation_compatible_with - default_value_comparator = ( - default_value_comparator or is_default_value_compatible_with + equality_vector = param_comparator( + param1, + param2, + name=name, + kind=kind, + default=default, + annotation=annotation, + aggreg=tuple, ) + comparison_dict = param_comparison_dict(param1, param2) + return { + key: comparison_dict[key] + for key, equal in zip(comparison_dict, equality_vector) + if not equal + } + - return annotation_comparator( - p1.annotation, p2.annotation - ) and default_value_comparator(p1.default, p2.default) +def defaults_are_the_same_when_not_empty(dflt1, dflt2): + """ + Check if two defaults are the same when they are not empty. + + # >>> defaults_are_the_same_when_not_empty(1, 1) + # True + # >>> defaults_are_the_same_when_not_empty(1, 2) + # False + # >>> defaults_are_the_same_when_not_empty(1, None) + # False + # >>> defaults_are_the_same_when_not_empty(1, Parameter.empty) + # True + """ + return dflt1 is empty or dflt2 is empty or dflt1 == dflt2 +def dflt1_is_empty_or_dflt2_is_not(dflt1, dflt2): + """ + Why such a strange default comparison function? + + This is to be used as a default in is_call_compatible_with. + + Consider two functions func1 and func2 with a parameter p with default values + dflt1 and dflt2 respectively. + If dflt1 was not empty and dflt2 was, this would mean that func1 could be called + without specifying p, but func2 couldn't. + + So to avoid this situation, we use dflt1_is_empty_or_dflt2_is_not as the default + + """ + return dflt1 is empty or dflt2 is not empty + + +# TODO: It seems like param_comparator is really only used to compare parameters on defaults. +# This may be due to the fact that is_call_compatible_with was developed independently +# from the other general param_comparator functionality that was developed (see above) +# The code of is_call_compatible_with should be reviwed and refactored to use general +# tools. def is_call_compatible_with( sig1: Sig, sig2: Sig, *, param_comparator: ParamComparator = None ) -> bool: @@ -4809,8 +4892,7 @@ def validate_param_compatibility(): return False return True - # TODO: Consider putting is_param_compatible_with as default instead - param_comparator = param_comparator or is_param_compatible_with + param_comparator = param_comparator or dflt1_is_empty_or_dflt2_is_not pos1, pks1, vp1, kos1, vk1 = sig1.detail_names_by_kind() ps1 = pos1 + pks1 @@ -4835,3 +4917,148 @@ def validate_param_compatibility(): and validate_param_positions() and validate_param_compatibility() ) + + +from dataclasses import dataclass + +from functools import cached_property +from dataclasses import dataclass +from i2.signatures import ( + Sig, + SignatureAble, + is_call_compatible_with, + param_comparator, + ParamComparator, + ComparisonAggreg, + param_differences_dict, +) +from inspect import Parameter + + +@dataclass +class SigComparison: + """ + Class to compare two signatures. + + :param sig1: First signature or signature-able object. + :param sig2: Second signature or signature-able object. + """ + + sig1: Union[Callable, Sig] + sig2: Union[Callable, Sig] + + def __post_init__(self): + self.sig1 = Sig(self.sig1) + self.sig2 = Sig(self.sig2) + + @cached_property + def shared_names(self): + """ + List of names that are common to both signatures, in the order of sig1. + + >>> sig1 = Sig(lambda a, b, c: None) + >>> sig2 = Sig(lambda b, c, d: None) + >>> comp = SigComparison(sig1, sig2) + >>> comp.shared_names + ['b', 'c'] + """ + return [name for name in self.sig1.names if name in self.sig2.names] + + @cached_property + def names_missing_in_sig2(self): + """ + List of names that are in the sig1 signature but not in sig2. + + >>> sig1 = Sig(lambda a, b, c: None) + >>> sig2 = Sig(lambda b, c, d: None) + >>> comp = SigComparison(sig1, sig2) + >>> comp.names_missing_in_sig2 + ['a'] + """ + return [name for name in self.sig1.names if name not in self.sig2.names] + + @cached_property + def names_missing_in_sig1(self): + """ + List of names that are in the sig2 signature but not in sig1. + + >>> sig1 = Sig(lambda a, b, c: None) + >>> sig2 = Sig(lambda b, c, d: None) + >>> comp = SigComparison(sig1, sig2) + >>> comp.names_missing_in_sig1 + ['d'] + """ + return [name for name in self.sig2.names if name not in self.sig1.names] + + # TODO: Verify that the doctests are correct! + def are_call_compatible(self, param_comparator = None) -> bool: + """ + Check if the signatures are call-compatible. + + Returns True if sig1 can be used to call sig2 or vice versa. + + >>> sig1 = Sig(lambda a, b, c=3: None) + >>> sig2 = Sig(lambda a, b: None) + >>> comp = SigComparison(sig1, sig2) + >>> comp.are_call_compatible() + False + + >>> comp = SigComparison(sig2, sig1) + >>> comp.are_call_compatible() + True + """ + return is_call_compatible_with( + self.sig1, self.sig2, param_comparator=param_comparator + ) + + def param_comparison( + self, + comparator = param_comparator, + aggregation = all, + ) -> bool: + """ + Compare parameters between the two signatures using the provided comparator function. + + :param comparator: A function to compare two parameters. + :param aggregation: A function to aggregate the results of the comparisons. + :return: Boolean result of the aggregated comparisons. + + >>> sig1 = Sig('(a, b: int, c=3)') + >>> sig2 = Sig('(a, *, b=2, d=4)') + >>> comp = SigComparison(sig1, sig2) + >>> comp.param_comparison() + False + """ + results = [ + comparator(self.sig1.parameters[name], self.sig2.parameters[name]) + for name in self.shared_names + ] + return aggregation(results) + + def param_differences(self) -> dict: + """ + Get a dictionary of parameter differences between the two signatures. + + :return: A dictionary containing differences for each parameter. + + >>> sig1 = Sig('(a, b: int, c=3)') + >>> sig2 = Sig('(a, *, b=2, d=4)') + >>> comp = SigComparison(sig1, sig2) + >>> result = comp.param_differences() + >>> expected = { + ... 'a': {}, + ... 'b': { + ... 'kind': (Parameter.POSITIONAL_OR_KEYWORD, Parameter.KEYWORD_ONLY), + ... 'default': (Parameter.empty, 2), + ... 'annotation': (int, Parameter.empty), + ... }, + ... } + >>> result == expected + True + """ + return { + name: param_differences_dict( + self.sig1.parameters[name], self.sig2.parameters[name] + ) + for name in self.shared_names + } \ No newline at end of file