From 0ada36f0a9220c8ffe32992c2b8cc57b8ae06de1 Mon Sep 17 00:00:00 2001 From: Filippo Vicentini Date: Sat, 30 Sep 2023 19:04:46 +0200 Subject: [PATCH] Update function to use methods --- plum/function.py | 30 +++++++++++++++--------------- tests/test_function.py | 17 ++++++++++++----- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/plum/function.py b/plum/function.py index fd53c65a..93a58c32 100644 --- a/plum/function.py +++ b/plum/function.py @@ -5,8 +5,9 @@ from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union from copy import copy +from .method import Method from .resolver import AmbiguousLookupError, NotFoundLookupError, Resolver -from .signature import Signature, append_default_args, extract_signature +from .signature import Signature, append_default_args from .type import resolve_type_hint from .util import TypeHint, repr_short @@ -183,7 +184,7 @@ def __doc__(self, value: str) -> None: def methods(self) -> List[Signature]: """list[:class:`.signature.Signature`]: All available methods.""" self._resolve_pending_registrations() - return self._resolver.signatures + return self._resolver.methods def dispatch( self: Self, method: Optional[Callable] = None, precedence=0 @@ -278,24 +279,23 @@ def _resolve_pending_registrations(self) -> None: # Obtain the signature if it is not available. if signature is None: - signature = extract_signature(f, precedence=precedence) + signature = Signature.from_callable(f, precedence=precedence) else: # Ensure that the implementation is `f`, but make a copy before # mutating. signature = copy(signature) - signature.implementation = f # Ensure that the implementation has the right name, because this name # will show up in the docstring. - if getattr(signature.implementation, "__name__", None) != self.__name__: - signature.implementation = _change_function_name( - signature.implementation, - self.__name__, - ) + if getattr(f, "__name__", None) != self.__name__: + f_renamed = _change_function_name(f, self.__name__) + else: + f_renamed = f # Process default values. for subsignature in append_default_args(signature, f): - self._resolver.register(subsignature) + submethod = Method(f_renamed, subsignature, function_name=self.__name__) + self._resolver.register(submethod) registered = True if registered: @@ -339,18 +339,18 @@ def resolve_method( try: # Attempt to find the method using the resolver. - signature = self._resolver.resolve(target) - method = signature.implementation - return_type = signature.return_type + method = self._resolver.resolve(target) + impl = method.implementation + return_type = method.return_type except AmbiguousLookupError as e: raise self._enhance_exception(e) # Specify this function. except NotFoundLookupError as e: e = self._enhance_exception(e) # Specify this function. - method, return_type = self._handle_not_found_lookup_error(e) + impl, return_type = self._handle_not_found_lookup_error(e) - return method, return_type + return impl, return_type def _handle_not_found_lookup_error( self, ex: NotFoundLookupError diff --git a/tests/test_function.py b/tests/test_function.py index 234875d6..650a462b 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -6,6 +6,7 @@ import pytest from plum import Dispatcher +from plum.method import Method from plum.function import Function, _change_function_name, _convert, _owner_transfer from plum.resolver import AmbiguousLookupError, NotFoundLookupError from plum.signature import Signature @@ -250,15 +251,21 @@ def f(x: str): def test_methods(): dispatch = Dispatcher() - @dispatch def f(x: int): pass - @dispatch + method1 = Method(f, Signature(int), function_name="f") + f_dispatch = dispatch(f) + def f(x: float): pass - assert f.methods == [Signature(int), Signature(float)] + method2 = Method(f, Signature(float), function_name="f") + dispatch(f) + + methods = [method1, method2] + + assert f_dispatch.methods == methods def test_function_dispatch(): @@ -279,7 +286,7 @@ def other_implementation(x: str): assert f(1) == "int" assert f(1.0) == "float" assert f("1") == "str" - assert f._resolver.resolve(("1",)).precedence == 1 + assert f._resolver.resolve(("1",)).signature.precedence == 1 def test_function_multi_dispatch(): @@ -296,7 +303,7 @@ def implementation(x): assert f(1) == "int" assert f(1.0) == "float or str" assert f("1") == "float or str" - assert f._resolver.resolve(("1",)).precedence == 1 + assert f._resolver.resolve(("1",)).signature.precedence == 1 # Check that arguments to `f.dispatch_multi` must be tuples or signatures. with pytest.raises(ValueError):