From 6de73bf44e8d7ecdcd6d4550a34b9d6f64a4d088 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Tue, 11 Jun 2024 12:45:45 -0400 Subject: [PATCH] refactor: small cleanup (#157) * refactor: small cleanup Signed-off-by: nstarman * feat: Dispatcher is a dataclass Signed-off-by: nstarman * fix: union Signed-off-by: nstarman --------- Signed-off-by: nstarman --- plum/dispatcher.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/plum/dispatcher.py b/plum/dispatcher.py index 0f2b5adc..394b6bf1 100644 --- a/plum/dispatcher.py +++ b/plum/dispatcher.py @@ -1,4 +1,7 @@ -from typing import Any, Dict, Optional, Tuple, TypeVar, Union +import sys +from dataclasses import dataclass, field +from functools import partial +from typing import Any, Dict, Optional, Tuple, TypeVar, Union, overload from .function import Function from .overload import get_overloads @@ -10,6 +13,12 @@ T = TypeVar("T", bound=Callable[..., Any]) +_dataclass_kwargs: Dict[str, Any] = {} +if sys.version_info >= (3, 10): + _dataclass_kwargs |= {"slots": True} + + +@dataclass(frozen=True, **_dataclass_kwargs) class Dispatcher: """A namespace for functions. @@ -19,11 +28,18 @@ class Dispatcher: all classes by the qualified name of a class. """ - def __init__(self): - self.functions: Dict[str, Function] = {} - self.classes: Dict[str, Dict[str, Function]] = {} + functions: Dict[str, Function] = field(default_factory=dict) + classes: Dict[str, Dict[str, Function]] = field(default_factory=dict) + + @overload + def __call__(self, method: T, precedence: int = ...) -> T: ... + + @overload + def __call__(self, method: None, precedence: int) -> Callable[[T], T]: ... - def __call__(self, method: Optional[T] = None, precedence: int = 0) -> T: + def __call__( + self, method: Optional[T] = None, precedence: int = 0 + ) -> Union[T, Callable[[T], T]]: """Decorator to register for a particular signature. Args: @@ -33,7 +49,7 @@ def __call__(self, method: Optional[T] = None, precedence: int = 0) -> T: function: Decorator. """ if method is None: - return lambda m: self(m, precedence=precedence) + return partial(self.__call__, precedence=precedence) # If `method` has overloads, assume that those overloads need to be registered # and that `method` is not an implementation.