Skip to content

Commit

Permalink
Update function to use methods
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipVinc committed Sep 30, 2023
1 parent f17def6 commit 0ada36f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 20 deletions.
30 changes: 15 additions & 15 deletions plum/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union
from copy import copy

from .method import Method
from .resolver import AmbiguousLookupError, NotFoundLookupError, Resolver
from .signature import Signature, append_default_args, extract_signature
from .signature import Signature, append_default_args
from .type import resolve_type_hint
from .util import TypeHint, repr_short

Expand Down Expand Up @@ -183,7 +184,7 @@ def __doc__(self, value: str) -> None:
def methods(self) -> List[Signature]:
"""list[:class:`.signature.Signature`]: All available methods."""
self._resolve_pending_registrations()
return self._resolver.signatures
return self._resolver.methods

def dispatch(
self: Self, method: Optional[Callable] = None, precedence=0
Expand Down Expand Up @@ -278,24 +279,23 @@ def _resolve_pending_registrations(self) -> None:

# Obtain the signature if it is not available.
if signature is None:
signature = extract_signature(f, precedence=precedence)
signature = Signature.from_callable(f, precedence=precedence)
else:
# Ensure that the implementation is `f`, but make a copy before
# mutating.
signature = copy(signature)
signature.implementation = f

# Ensure that the implementation has the right name, because this name
# will show up in the docstring.
if getattr(signature.implementation, "__name__", None) != self.__name__:
signature.implementation = _change_function_name(
signature.implementation,
self.__name__,
)
if getattr(f, "__name__", None) != self.__name__:
f_renamed = _change_function_name(f, self.__name__)
else:
f_renamed = f

# Process default values.
for subsignature in append_default_args(signature, f):
self._resolver.register(subsignature)
submethod = Method(f_renamed, subsignature, function_name=self.__name__)
self._resolver.register(submethod)
registered = True

if registered:
Expand Down Expand Up @@ -339,18 +339,18 @@ def resolve_method(

try:
# Attempt to find the method using the resolver.
signature = self._resolver.resolve(target)
method = signature.implementation
return_type = signature.return_type
method = self._resolver.resolve(target)
impl = method.implementation
return_type = method.return_type

except AmbiguousLookupError as e:
raise self._enhance_exception(e) # Specify this function.

except NotFoundLookupError as e:
e = self._enhance_exception(e) # Specify this function.
method, return_type = self._handle_not_found_lookup_error(e)
impl, return_type = self._handle_not_found_lookup_error(e)

return method, return_type
return impl, return_type

def _handle_not_found_lookup_error(
self, ex: NotFoundLookupError
Expand Down
17 changes: 12 additions & 5 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from plum import Dispatcher
from plum.method import Method
from plum.function import Function, _change_function_name, _convert, _owner_transfer
from plum.resolver import AmbiguousLookupError, NotFoundLookupError
from plum.signature import Signature
Expand Down Expand Up @@ -250,15 +251,21 @@ def f(x: str):
def test_methods():
dispatch = Dispatcher()

@dispatch
def f(x: int):
pass

@dispatch
method1 = Method(f, Signature(int), function_name="f")
f_dispatch = dispatch(f)

def f(x: float):
pass

assert f.methods == [Signature(int), Signature(float)]
method2 = Method(f, Signature(float), function_name="f")
dispatch(f)

methods = [method1, method2]

assert f_dispatch.methods == methods


def test_function_dispatch():
Expand All @@ -279,7 +286,7 @@ def other_implementation(x: str):
assert f(1) == "int"
assert f(1.0) == "float"
assert f("1") == "str"
assert f._resolver.resolve(("1",)).precedence == 1
assert f._resolver.resolve(("1",)).signature.precedence == 1


def test_function_multi_dispatch():
Expand All @@ -296,7 +303,7 @@ def implementation(x):
assert f(1) == "int"
assert f(1.0) == "float or str"
assert f("1") == "float or str"
assert f._resolver.resolve(("1",)).precedence == 1
assert f._resolver.resolve(("1",)).signature.precedence == 1

# Check that arguments to `f.dispatch_multi` must be tuples or signatures.
with pytest.raises(ValueError):
Expand Down

0 comments on commit 0ada36f

Please sign in to comment.