Skip to content

Commit

Permalink
refactor: small cleanup (#157)
Browse files Browse the repository at this point in the history
* refactor: small cleanup

Signed-off-by: nstarman <[email protected]>

* feat: Dispatcher is a dataclass

Signed-off-by: nstarman <[email protected]>

* fix: union

Signed-off-by: nstarman <[email protected]>

---------

Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Jun 11, 2024
1 parent 6dcb9f1 commit 6de73bf
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions plum/dispatcher.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Any, Dict, Optional, Tuple, TypeVar, Union
import sys
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, Optional, Tuple, TypeVar, Union, overload

from .function import Function
from .overload import get_overloads
Expand All @@ -10,6 +13,12 @@
T = TypeVar("T", bound=Callable[..., Any])


_dataclass_kwargs: Dict[str, Any] = {}
if sys.version_info >= (3, 10):
_dataclass_kwargs |= {"slots": True}


@dataclass(frozen=True, **_dataclass_kwargs)
class Dispatcher:
"""A namespace for functions.
Expand All @@ -19,11 +28,18 @@ class Dispatcher:
all classes by the qualified name of a class.
"""

def __init__(self):
self.functions: Dict[str, Function] = {}
self.classes: Dict[str, Dict[str, Function]] = {}
functions: Dict[str, Function] = field(default_factory=dict)
classes: Dict[str, Dict[str, Function]] = field(default_factory=dict)

@overload
def __call__(self, method: T, precedence: int = ...) -> T: ...

@overload
def __call__(self, method: None, precedence: int) -> Callable[[T], T]: ...

def __call__(self, method: Optional[T] = None, precedence: int = 0) -> T:
def __call__(
self, method: Optional[T] = None, precedence: int = 0
) -> Union[T, Callable[[T], T]]:
"""Decorator to register for a particular signature.
Args:
Expand All @@ -33,7 +49,7 @@ def __call__(self, method: Optional[T] = None, precedence: int = 0) -> T:
function: Decorator.
"""
if method is None:
return lambda m: self(m, precedence=precedence)
return partial(self.__call__, precedence=precedence)

# If `method` has overloads, assume that those overloads need to be registered
# and that `method` is not an implementation.
Expand Down

0 comments on commit 6de73bf

Please sign in to comment.