From bcd6cfe780adba96b8083e32095b40a7566dda1b Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Mon, 27 Feb 2023 16:11:30 -0800 Subject: [PATCH] can nickname any func now --- src/springs/core.py | 17 ++++---- src/springs/initialize.py | 21 +++++++--- src/springs/nicknames.py | 84 +++++++++++++++++++++++++-------------- src/springs/shortcuts.py | 8 ++-- tests/test_create.py | 15 +++---- tests/test_nicknames.py | 56 ++++++++++++++++++++++++++ 6 files changed, 145 insertions(+), 56 deletions(-) diff --git a/src/springs/core.py b/src/springs/core.py index 858b42c..920c654 100644 --- a/src/springs/core.py +++ b/src/springs/core.py @@ -5,16 +5,15 @@ from functools import reduce from inspect import isclass from pathlib import Path -from typing import ( - Any, Dict, List, Sequence, Tuple, TypeVar, Union, Callable, - cast as typing_cast, overload -) +from typing import Any, Callable, Dict, List, Sequence, Tuple, TypeVar, Union +from typing import cast as typing_cast +from typing import overload from omegaconf import DictConfig, ListConfig, OmegaConf from omegaconf.errors import MissingMandatoryValue from omegaconf.omegaconf import DictKeyType +from typing_extensions import Concatenate, ParamSpec from yaml.scanner import ScannerError -from typing_extensions import ParamSpec, Concatenate from .flexyclasses import FlexyClass from .traversal import FailedParamSpec, traverse @@ -25,7 +24,7 @@ T = TypeVar("T") R = TypeVar("R") C = TypeVar("C", bound=Union[DictConfig, ListConfig]) -PS = ParamSpec('PS') +PS = ParamSpec("PS") def cast(config: Any, copy: bool = False) -> DictConfig: @@ -59,9 +58,7 @@ def _from_allow_none_or_skip( or returns the input if it is already a DictConfig or ListConfig""" def wrapped( - config: Union[T, None], - *args: PS.args, - **kwargs: PS.kwargs + config: Union[T, None], *args: PS.args, **kwargs: PS.kwargs ) -> R: if config is None: return from_none() @@ -101,7 +98,7 @@ def from_python(config: List[Any]) -> ListConfig: ... -@_from_allow_none_or_skip # type: ignore +@_from_allow_none_or_skip # type: ignore def from_python( config: Union[Dict[DictKeyType, Any], Dict[str, Any], List[Any]] ) -> Union[DictConfig, ListConfig]: diff --git a/src/springs/initialize.py b/src/springs/initialize.py index f44b067..711fcda 100644 --- a/src/springs/initialize.py +++ b/src/springs/initialize.py @@ -385,8 +385,8 @@ def _find_child_type( initializer, this function figures out the expected type using a combination of class annotations and __init__ annotations. - If the type the attribute should be cannot be determined, it - simply returns None. + If the type the attribute cannot be determined, it simply returns + None. An example: @@ -418,12 +418,21 @@ def __init__(self, a): return None if attr_name in (parent_anns := get_annotations(cls_)): + # this is good for cases where the attribute is annotated + # in the class definition return parent_anns[attr_name] - if attr_name in ( - parent_anns := inspect.getfullargspec(cls_).annotations - ): - return parent_anns[attr_name] + try: + # this fails on C types; we give up on nested initialization + # on those. + spec = inspect.getfullargspec(cls_) + except TypeError: + return None + + if attr_name in spec.annotations: + # this is good for cases where the attribute is annotated + # in a function definition + return spec.annotations[attr_name] # this is the case where we cannot resolve anything return None diff --git a/src/springs/nicknames.py b/src/springs/nicknames.py index 3de1554..6f72b47 100644 --- a/src/springs/nicknames.py +++ b/src/springs/nicknames.py @@ -1,8 +1,8 @@ +import inspect from dataclasses import is_dataclass from inspect import isclass from pathlib import Path from typing import ( - Any, Callable, Dict, Literal, @@ -13,23 +13,23 @@ Type, TypeVar, Union, - cast, overload, ) -from typing_extensions import ParamSpec -from omegaconf import DictConfig, ListConfig +from omegaconf import MISSING, DictConfig, ListConfig +from typing_extensions import ParamSpec -from .core import from_file +from .core import from_dict, from_file from .flexyclasses import FlexyClass from .logging import configure_logging -RegistryValue = Union[Type[Any], Type[FlexyClass], DictConfig, ListConfig] +RegistryValue = Union[Callable, Type[FlexyClass], DictConfig, ListConfig] +# M = TypeVar("M", bound=RegistryValue) T = TypeVar("T") -M = TypeVar("M", bound=RegistryValue) P = ParamSpec("P") + LOGGER = configure_logging(__name__) @@ -87,38 +87,62 @@ def scan( LOGGER.warning(f"Could not load config from {path}") @classmethod - def _add(cls, name: str, config: M) -> M: + def _add(cls, name: str, config: RegistryValue) -> RegistryValue: cls.__registry__[name] = config return config - # @overload - # @classmethod - # def add(cls, name: str) -> Callable[[Type[T]], Type[T]]: - # ... + @overload + @classmethod + def add(cls, name: str) -> Callable[[Type[T]], Type[T]]: + ... - # @overload - # @classmethod - # def add(cls, name: str) -> Callable[[Callable[P, T]], Callable[P, T]]: - # ... + @overload + @classmethod + def add( # type: ignore + cls, name: str + ) -> Callable[[Callable[P, T]], Callable[P, T]]: + ... @classmethod - def add(cls, name: str) -> Callable[[Callable[P, T]], Callable[P, T]]: + def add( + cls, name: str + ) -> Union[ + Callable[[Type[T]], Type[T]], + Callable[[Callable[P, T]], Callable[P, T]], + ]: """Decorator to save a structured configuration with a nickname for easy reuse.""" - def add_to_registry(cls_: Callable[P, T]) -> Callable[P, T]: - if is_dataclass(cls_): - pass - elif isclass(cls_) and issubclass(cls_, FlexyClass): - pass + if name in cls.__registry__: + raise ValueError(f"{name} is already registered") + + def add_to_registry( + cls_or_fn: Union[Type[T], Callable[P, T]] + ) -> Union[Type[T], Callable[P, T]]: + if is_dataclass(cls_or_fn): + # Pylance complains about dataclasses not being a valid type, + # but the problem is DataclassInstance is only defined within + # Pylance, so I can't type annotate with that. + cls._add(name, cls_or_fn) # pyright: ignore + elif isclass(cls_or_fn) and issubclass(cls_or_fn, FlexyClass): + cls._add(name, cls_or_fn) else: - raise ValueError(f"{cls_} must be a dataclass") - - if name in cls.__registry__: - raise ValueError(f"{name} is already registered") - return cls._add(name, cls_) - - return add_to_registry + from .initialize import Target, init + + sig = inspect.signature(cls_or_fn) + entry = from_dict( + { + init.TARGET: Target.to_string(cls_or_fn), + **{ + k: (v.default if v.default != v.empty else MISSING) + for k, v in sig.parameters.items() + }, + } + ) + cls._add(name, entry) + return cls_or_fn # type: ignore + + return add_to_registry # type: ignore @overload @classmethod @@ -147,7 +171,7 @@ def all(cls) -> Sequence[Tuple[str, str]]: return [ ( name, - getattr(config, '__name__', type(config).__name__) + getattr(config, "__name__", type(config).__name__) if is_dataclass(config) else type(config).__name__, ) diff --git a/src/springs/shortcuts.py b/src/springs/shortcuts.py index 8c74aae..fd4b69c 100644 --- a/src/springs/shortcuts.py +++ b/src/springs/shortcuts.py @@ -9,11 +9,12 @@ Optional, Sequence, Set, - Type, TypeVar, Union, ) +from typing_extensions import ParamSpec + from .field_utils import field from .flexyclasses import flexyclass from .initialize import Target @@ -22,6 +23,7 @@ from .utils import SpringsConfig, SpringsWarnings T = TypeVar("T") +P = ParamSpec("P") def get_nickname( @@ -45,9 +47,9 @@ def make_target(c: Callable) -> str: return Target.to_string(c) -def nickname(name: str) -> Callable[[Type[T]], Type[T]]: +def nickname(name: str) -> Callable[[T], T]: """Shortcut for springs.nicknames.NicknameRegistry.add""" - return NicknameRegistry.add(name) + return NicknameRegistry.add(name) # type: ignore def scan( diff --git a/tests/test_create.py b/tests/test_create.py index 5800a48..d026e2d 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -1,13 +1,14 @@ -from tempfile import NamedTemporaryFile import unittest +from tempfile import NamedTemporaryFile from omegaconf import DictConfig, ListConfig + import springs as sp @sp.dataclass class DT: - foo: str = 'bar' + foo: str = "bar" class TestCreation(unittest.TestCase): @@ -19,7 +20,7 @@ def test_from_dict(self): self.assertEqual(sp.to_dict(sp.from_dict(DictConfig({}))), {}) def test_from_dataclass(self): - self.assertEqual(sp.to_dict(sp.from_dataclass(DT)), {'foo': 'bar'}) + self.assertEqual(sp.to_dict(sp.from_dataclass(DT)), {"foo": "bar"}) self.assertEqual(sp.to_dict(sp.from_dataclass(None)), {}) self.assertEqual(sp.to_dict(sp.from_dataclass(DictConfig({}))), {}) @@ -31,11 +32,11 @@ def test_from_python(self): sp.to_python(sp.from_python(None)), {} # type: ignore ) self.assertEqual( - sp.to_python(sp.from_python(ListConfig([]))), [] # type: ignore + sp.to_python(sp.from_python(ListConfig([]))), [] # type: ignore ) def test_from_file(self): - with NamedTemporaryFile('w') as f: - f.write('foo: bar') + with NamedTemporaryFile("w") as f: + f.write("foo: bar") f.flush() - self.assertEqual(sp.to_dict(sp.from_file(f.name)), {'foo': 'bar'}) + self.assertEqual(sp.to_dict(sp.from_file(f.name)), {"foo": "bar"}) diff --git a/tests/test_nicknames.py b/tests/test_nicknames.py index 73a4bde..8bc1556 100644 --- a/tests/test_nicknames.py +++ b/tests/test_nicknames.py @@ -22,6 +22,17 @@ class DevConfig: batch_size: int = 32 +@sp.nickname("class_nickname") +class NC: + def __init__(self) -> None: + self.foo = "bar" + + +@sp.nickname("function_nickname") +def nf(text: str = "bar"): + return text + + class TestNicknames(unittest.TestCase): def setUp(self) -> None: self.cfg = sp.from_dict( @@ -64,3 +75,48 @@ def test_dict_nicknames(self): mod2 = sp.get_nickname("test/temp/config") self.assertTrue(isinstance(mod2, DictConfig)) self.assertEqual(sp.to_python(mod), sp.to_python(mod2)) + + def test_class_nickname(self): + mod = sp.get_nickname("class_nickname") + self.assertEqual( + mod, sp.from_dict({"_target_": "tests.test_nicknames.NC"}) + ) + + obj = sp.init.now(mod, NC) + self.assertTrue(isinstance(obj, NC)) + self.assertEqual(obj.foo, "bar") + + def test_function_nickname(self): + mod = sp.get_nickname("function_nickname") + self.assertEqual( + mod, + sp.from_dict( + { + "_target_": "tests.test_nicknames.nf", + "text": "bar", + } + ), + ) + + obj = sp.init.now(mod, str) + self.assertEqual(obj, "bar") + + obj = sp.init.now(mod, str, text="foo") + self.assertEqual(obj, "foo") + + def test_nickname_from_resolver(self): + mod = sp.from_string("func: ${sp.ref:function_nickname}") + mod = sp.resolve(mod) + self.assertEqual( + mod, + sp.from_python( + { + "func": { + "_target_": "tests.test_nicknames.nf", + "text": "bar", + } + } + ), + ) + obj = sp.init.now(mod.func, str, text="foo") + self.assertEqual(obj, "foo")