From 425c2e7abef571ef8e338eddcfb4ea0f8905912c Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Mon, 27 Feb 2023 14:12:08 -0800 Subject: [PATCH 1/3] . --- pyproject.toml | 2 +- src/springs/core.py | 41 +++++++++++++++++++++++++++++++++------ src/springs/nicknames.py | 30 +++++++++++++++++++--------- src/springs/rich_utils.py | 6 ++++++ 4 files changed, 63 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 214a39e..aecb996 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "springs" -version = "1.11.1" +version = "1.12" description = """\ A set of utilities to create and manage typed configuration files \ effectively, built on top of OmegaConf.\ diff --git a/src/springs/core.py b/src/springs/core.py index 0d2d309..858b42c 100644 --- a/src/springs/core.py +++ b/src/springs/core.py @@ -5,14 +5,16 @@ from functools import reduce from inspect import isclass from pathlib import Path -from typing import Any, Dict, List, Sequence, Tuple, TypeVar, Union -from typing import cast as typing_cast -from typing import overload +from typing import ( + Any, Dict, List, Sequence, Tuple, TypeVar, Union, Callable, + cast as typing_cast, overload +) from omegaconf import DictConfig, ListConfig, OmegaConf from omegaconf.errors import MissingMandatoryValue from omegaconf.omegaconf import DictKeyType from yaml.scanner import ScannerError +from typing_extensions import ParamSpec, Concatenate from .flexyclasses import FlexyClass from .traversal import FailedParamSpec, traverse @@ -20,7 +22,10 @@ DEFAULT: Any = "***" +T = TypeVar("T") +R = TypeVar("R") C = TypeVar("C", bound=Union[DictConfig, ListConfig]) +PS = ParamSpec('PS') def cast(config: Any, copy: bool = False) -> DictConfig: @@ -47,11 +52,30 @@ def from_none(*args: Any, **kwargs: Any) -> DictConfig: return OmegaConf.create() +def _from_allow_none_or_skip( + fn: Callable[Concatenate[T, PS], R] +) -> Callable[Concatenate[Union[T, DictConfig, ListConfig, None], PS], R]: + """Decorator that creates an empty config if the input is None, + or returns the input if it is already a DictConfig or ListConfig""" + + def wrapped( + config: Union[T, None], + *args: PS.args, + **kwargs: PS.kwargs + ) -> R: + if config is None: + return from_none() + if isinstance(config, (DictConfig, ListConfig)): + return config + + return fn(config, *args, **kwargs) + + return wrapped + + +@_from_allow_none_or_skip def from_dataclass(config: Any) -> DictConfig: """Cast a dataclass to a structured omega config""" - if config is None: - return from_none() - if isclass(config) and issubclass(config, FlexyClass): config = config.defaults() @@ -77,6 +101,7 @@ def from_python(config: List[Any]) -> ListConfig: ... +@_from_allow_none_or_skip # type: ignore def from_python( config: Union[Dict[DictKeyType, Any], Dict[str, Any], List[Any]] ) -> Union[DictConfig, ListConfig]: @@ -94,6 +119,7 @@ def from_python( return parsed_config +@_from_allow_none_or_skip def from_dict( config: Union[Dict[DictKeyType, Any], Dict[str, Any]] ) -> DictConfig: @@ -104,6 +130,7 @@ def from_dict( return from_python(config) # type: ignore +@_from_allow_none_or_skip def from_string(config: str) -> DictConfig: """Load a config from a string""" if not isinstance(config, str): @@ -116,6 +143,7 @@ def from_string(config: str) -> DictConfig: return parsed_config +@_from_allow_none_or_skip def from_file(path: Union[str, Path]) -> DictConfig: """Load a config from a file, either YAML or JSON""" path = Path(path) @@ -142,6 +170,7 @@ def from_file(path: Union[str, Path]) -> DictConfig: return config +@_from_allow_none_or_skip def from_options(opts: Sequence[str]) -> DictConfig: """Create a config from a list of options""" if not isinstance(opts, abc.Sequence) or not all( diff --git a/src/springs/nicknames.py b/src/springs/nicknames.py index 3dca2e9..3de1554 100644 --- a/src/springs/nicknames.py +++ b/src/springs/nicknames.py @@ -16,6 +16,7 @@ cast, overload, ) +from typing_extensions import ParamSpec from omegaconf import DictConfig, ListConfig @@ -27,6 +28,7 @@ T = TypeVar("T") M = TypeVar("M", bound=RegistryValue) +P = ParamSpec("P") LOGGER = configure_logging(__name__) @@ -89,22 +91,32 @@ def _add(cls, name: str, config: M) -> M: cls.__registry__[name] = config return config + # @overload + # @classmethod + # def add(cls, name: str) -> Callable[[Type[T]], Type[T]]: + # ... + + # @overload + # @classmethod + # def add(cls, name: str) -> Callable[[Callable[P, T]], Callable[P, T]]: + # ... + @classmethod - def add(cls, name: str) -> Callable[[Type[T]], Type[T]]: + def add(cls, name: str) -> Callable[[Callable[P, T]], Callable[P, T]]: """Decorator to save a structured configuration with a nickname for easy reuse.""" - def add_to_registry(cls_: Type[T]) -> Type[T]: - if not ( - is_dataclass(cls_) - or isclass(cls_) - and issubclass(cls_, FlexyClass) - ): + def add_to_registry(cls_: Callable[P, T]) -> Callable[P, T]: + if is_dataclass(cls_): + pass + elif isclass(cls_) and issubclass(cls_, FlexyClass): + pass + else: raise ValueError(f"{cls_} must be a dataclass") if name in cls.__registry__: raise ValueError(f"{name} is already registered") - return cast(Type[T], cls._add(name, cls_)) + return cls._add(name, cls_) return add_to_registry @@ -135,7 +147,7 @@ def all(cls) -> Sequence[Tuple[str, str]]: return [ ( name, - str(config.__name__) + getattr(config, '__name__', type(config).__name__) if is_dataclass(config) else type(config).__name__, ) diff --git a/src/springs/rich_utils.py b/src/springs/rich_utils.py index e386a75..a4653ee 100644 --- a/src/springs/rich_utils.py +++ b/src/springs/rich_utils.py @@ -124,6 +124,12 @@ def add_pretty_traceback(**install_kwargs: Any) -> None: class RichArgumentParser(ArgumentParser): + theme: SpringsTheme + entrypoint: Optional[str] + arguments: Optional[str] + formatted: Dict[str, Any] + console_kwargs: Dict[str, Any] + def __init__( self, *args, From edf9925e66c643fbb8aa9112420befed41bb9862 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Mon, 27 Feb 2023 14:12:31 -0800 Subject: [PATCH 2/3] adding more lenient from support --- tests/test_create.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 tests/test_create.py diff --git a/tests/test_create.py b/tests/test_create.py new file mode 100644 index 0000000..5800a48 --- /dev/null +++ b/tests/test_create.py @@ -0,0 +1,41 @@ +from tempfile import NamedTemporaryFile +import unittest + +from omegaconf import DictConfig, ListConfig +import springs as sp + + +@sp.dataclass +class DT: + foo: str = 'bar' + + +class TestCreation(unittest.TestCase): + def test_from_dict(self): + self.assertEqual( + sp.to_dict(sp.from_dict({"foo": "bar"})), {"foo": "bar"} + ) + self.assertEqual(sp.to_dict(sp.from_dict(None)), {}) + self.assertEqual(sp.to_dict(sp.from_dict(DictConfig({}))), {}) + + def test_from_dataclass(self): + self.assertEqual(sp.to_dict(sp.from_dataclass(DT)), {'foo': 'bar'}) + self.assertEqual(sp.to_dict(sp.from_dataclass(None)), {}) + self.assertEqual(sp.to_dict(sp.from_dataclass(DictConfig({}))), {}) + + def test_from_python(self): + self.assertEqual( + sp.to_python(sp.from_python({"foo": "bar"})), {"foo": "bar"} + ) + self.assertEqual( + sp.to_python(sp.from_python(None)), {} # type: ignore + ) + self.assertEqual( + sp.to_python(sp.from_python(ListConfig([]))), [] # type: ignore + ) + + def test_from_file(self): + with NamedTemporaryFile('w') as f: + f.write('foo: bar') + f.flush() + self.assertEqual(sp.to_dict(sp.from_file(f.name)), {'foo': 'bar'}) From bcd6cfe780adba96b8083e32095b40a7566dda1b Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Mon, 27 Feb 2023 16:11:30 -0800 Subject: [PATCH 3/3] can nickname any func now --- src/springs/core.py | 17 ++++---- src/springs/initialize.py | 21 +++++++--- src/springs/nicknames.py | 84 +++++++++++++++++++++++++-------------- src/springs/shortcuts.py | 8 ++-- tests/test_create.py | 15 +++---- tests/test_nicknames.py | 56 ++++++++++++++++++++++++++ 6 files changed, 145 insertions(+), 56 deletions(-) diff --git a/src/springs/core.py b/src/springs/core.py index 858b42c..920c654 100644 --- a/src/springs/core.py +++ b/src/springs/core.py @@ -5,16 +5,15 @@ from functools import reduce from inspect import isclass from pathlib import Path -from typing import ( - Any, Dict, List, Sequence, Tuple, TypeVar, Union, Callable, - cast as typing_cast, overload -) +from typing import Any, Callable, Dict, List, Sequence, Tuple, TypeVar, Union +from typing import cast as typing_cast +from typing import overload from omegaconf import DictConfig, ListConfig, OmegaConf from omegaconf.errors import MissingMandatoryValue from omegaconf.omegaconf import DictKeyType +from typing_extensions import Concatenate, ParamSpec from yaml.scanner import ScannerError -from typing_extensions import ParamSpec, Concatenate from .flexyclasses import FlexyClass from .traversal import FailedParamSpec, traverse @@ -25,7 +24,7 @@ T = TypeVar("T") R = TypeVar("R") C = TypeVar("C", bound=Union[DictConfig, ListConfig]) -PS = ParamSpec('PS') +PS = ParamSpec("PS") def cast(config: Any, copy: bool = False) -> DictConfig: @@ -59,9 +58,7 @@ def _from_allow_none_or_skip( or returns the input if it is already a DictConfig or ListConfig""" def wrapped( - config: Union[T, None], - *args: PS.args, - **kwargs: PS.kwargs + config: Union[T, None], *args: PS.args, **kwargs: PS.kwargs ) -> R: if config is None: return from_none() @@ -101,7 +98,7 @@ def from_python(config: List[Any]) -> ListConfig: ... -@_from_allow_none_or_skip # type: ignore +@_from_allow_none_or_skip # type: ignore def from_python( config: Union[Dict[DictKeyType, Any], Dict[str, Any], List[Any]] ) -> Union[DictConfig, ListConfig]: diff --git a/src/springs/initialize.py b/src/springs/initialize.py index f44b067..711fcda 100644 --- a/src/springs/initialize.py +++ b/src/springs/initialize.py @@ -385,8 +385,8 @@ def _find_child_type( initializer, this function figures out the expected type using a combination of class annotations and __init__ annotations. - If the type the attribute should be cannot be determined, it - simply returns None. + If the type the attribute cannot be determined, it simply returns + None. An example: @@ -418,12 +418,21 @@ def __init__(self, a): return None if attr_name in (parent_anns := get_annotations(cls_)): + # this is good for cases where the attribute is annotated + # in the class definition return parent_anns[attr_name] - if attr_name in ( - parent_anns := inspect.getfullargspec(cls_).annotations - ): - return parent_anns[attr_name] + try: + # this fails on C types; we give up on nested initialization + # on those. + spec = inspect.getfullargspec(cls_) + except TypeError: + return None + + if attr_name in spec.annotations: + # this is good for cases where the attribute is annotated + # in a function definition + return spec.annotations[attr_name] # this is the case where we cannot resolve anything return None diff --git a/src/springs/nicknames.py b/src/springs/nicknames.py index 3de1554..6f72b47 100644 --- a/src/springs/nicknames.py +++ b/src/springs/nicknames.py @@ -1,8 +1,8 @@ +import inspect from dataclasses import is_dataclass from inspect import isclass from pathlib import Path from typing import ( - Any, Callable, Dict, Literal, @@ -13,23 +13,23 @@ Type, TypeVar, Union, - cast, overload, ) -from typing_extensions import ParamSpec -from omegaconf import DictConfig, ListConfig +from omegaconf import MISSING, DictConfig, ListConfig +from typing_extensions import ParamSpec -from .core import from_file +from .core import from_dict, from_file from .flexyclasses import FlexyClass from .logging import configure_logging -RegistryValue = Union[Type[Any], Type[FlexyClass], DictConfig, ListConfig] +RegistryValue = Union[Callable, Type[FlexyClass], DictConfig, ListConfig] +# M = TypeVar("M", bound=RegistryValue) T = TypeVar("T") -M = TypeVar("M", bound=RegistryValue) P = ParamSpec("P") + LOGGER = configure_logging(__name__) @@ -87,38 +87,62 @@ def scan( LOGGER.warning(f"Could not load config from {path}") @classmethod - def _add(cls, name: str, config: M) -> M: + def _add(cls, name: str, config: RegistryValue) -> RegistryValue: cls.__registry__[name] = config return config - # @overload - # @classmethod - # def add(cls, name: str) -> Callable[[Type[T]], Type[T]]: - # ... + @overload + @classmethod + def add(cls, name: str) -> Callable[[Type[T]], Type[T]]: + ... - # @overload - # @classmethod - # def add(cls, name: str) -> Callable[[Callable[P, T]], Callable[P, T]]: - # ... + @overload + @classmethod + def add( # type: ignore + cls, name: str + ) -> Callable[[Callable[P, T]], Callable[P, T]]: + ... @classmethod - def add(cls, name: str) -> Callable[[Callable[P, T]], Callable[P, T]]: + def add( + cls, name: str + ) -> Union[ + Callable[[Type[T]], Type[T]], + Callable[[Callable[P, T]], Callable[P, T]], + ]: """Decorator to save a structured configuration with a nickname for easy reuse.""" - def add_to_registry(cls_: Callable[P, T]) -> Callable[P, T]: - if is_dataclass(cls_): - pass - elif isclass(cls_) and issubclass(cls_, FlexyClass): - pass + if name in cls.__registry__: + raise ValueError(f"{name} is already registered") + + def add_to_registry( + cls_or_fn: Union[Type[T], Callable[P, T]] + ) -> Union[Type[T], Callable[P, T]]: + if is_dataclass(cls_or_fn): + # Pylance complains about dataclasses not being a valid type, + # but the problem is DataclassInstance is only defined within + # Pylance, so I can't type annotate with that. + cls._add(name, cls_or_fn) # pyright: ignore + elif isclass(cls_or_fn) and issubclass(cls_or_fn, FlexyClass): + cls._add(name, cls_or_fn) else: - raise ValueError(f"{cls_} must be a dataclass") - - if name in cls.__registry__: - raise ValueError(f"{name} is already registered") - return cls._add(name, cls_) - - return add_to_registry + from .initialize import Target, init + + sig = inspect.signature(cls_or_fn) + entry = from_dict( + { + init.TARGET: Target.to_string(cls_or_fn), + **{ + k: (v.default if v.default != v.empty else MISSING) + for k, v in sig.parameters.items() + }, + } + ) + cls._add(name, entry) + return cls_or_fn # type: ignore + + return add_to_registry # type: ignore @overload @classmethod @@ -147,7 +171,7 @@ def all(cls) -> Sequence[Tuple[str, str]]: return [ ( name, - getattr(config, '__name__', type(config).__name__) + getattr(config, "__name__", type(config).__name__) if is_dataclass(config) else type(config).__name__, ) diff --git a/src/springs/shortcuts.py b/src/springs/shortcuts.py index 8c74aae..fd4b69c 100644 --- a/src/springs/shortcuts.py +++ b/src/springs/shortcuts.py @@ -9,11 +9,12 @@ Optional, Sequence, Set, - Type, TypeVar, Union, ) +from typing_extensions import ParamSpec + from .field_utils import field from .flexyclasses import flexyclass from .initialize import Target @@ -22,6 +23,7 @@ from .utils import SpringsConfig, SpringsWarnings T = TypeVar("T") +P = ParamSpec("P") def get_nickname( @@ -45,9 +47,9 @@ def make_target(c: Callable) -> str: return Target.to_string(c) -def nickname(name: str) -> Callable[[Type[T]], Type[T]]: +def nickname(name: str) -> Callable[[T], T]: """Shortcut for springs.nicknames.NicknameRegistry.add""" - return NicknameRegistry.add(name) + return NicknameRegistry.add(name) # type: ignore def scan( diff --git a/tests/test_create.py b/tests/test_create.py index 5800a48..d026e2d 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -1,13 +1,14 @@ -from tempfile import NamedTemporaryFile import unittest +from tempfile import NamedTemporaryFile from omegaconf import DictConfig, ListConfig + import springs as sp @sp.dataclass class DT: - foo: str = 'bar' + foo: str = "bar" class TestCreation(unittest.TestCase): @@ -19,7 +20,7 @@ def test_from_dict(self): self.assertEqual(sp.to_dict(sp.from_dict(DictConfig({}))), {}) def test_from_dataclass(self): - self.assertEqual(sp.to_dict(sp.from_dataclass(DT)), {'foo': 'bar'}) + self.assertEqual(sp.to_dict(sp.from_dataclass(DT)), {"foo": "bar"}) self.assertEqual(sp.to_dict(sp.from_dataclass(None)), {}) self.assertEqual(sp.to_dict(sp.from_dataclass(DictConfig({}))), {}) @@ -31,11 +32,11 @@ def test_from_python(self): sp.to_python(sp.from_python(None)), {} # type: ignore ) self.assertEqual( - sp.to_python(sp.from_python(ListConfig([]))), [] # type: ignore + sp.to_python(sp.from_python(ListConfig([]))), [] # type: ignore ) def test_from_file(self): - with NamedTemporaryFile('w') as f: - f.write('foo: bar') + with NamedTemporaryFile("w") as f: + f.write("foo: bar") f.flush() - self.assertEqual(sp.to_dict(sp.from_file(f.name)), {'foo': 'bar'}) + self.assertEqual(sp.to_dict(sp.from_file(f.name)), {"foo": "bar"}) diff --git a/tests/test_nicknames.py b/tests/test_nicknames.py index 73a4bde..8bc1556 100644 --- a/tests/test_nicknames.py +++ b/tests/test_nicknames.py @@ -22,6 +22,17 @@ class DevConfig: batch_size: int = 32 +@sp.nickname("class_nickname") +class NC: + def __init__(self) -> None: + self.foo = "bar" + + +@sp.nickname("function_nickname") +def nf(text: str = "bar"): + return text + + class TestNicknames(unittest.TestCase): def setUp(self) -> None: self.cfg = sp.from_dict( @@ -64,3 +75,48 @@ def test_dict_nicknames(self): mod2 = sp.get_nickname("test/temp/config") self.assertTrue(isinstance(mod2, DictConfig)) self.assertEqual(sp.to_python(mod), sp.to_python(mod2)) + + def test_class_nickname(self): + mod = sp.get_nickname("class_nickname") + self.assertEqual( + mod, sp.from_dict({"_target_": "tests.test_nicknames.NC"}) + ) + + obj = sp.init.now(mod, NC) + self.assertTrue(isinstance(obj, NC)) + self.assertEqual(obj.foo, "bar") + + def test_function_nickname(self): + mod = sp.get_nickname("function_nickname") + self.assertEqual( + mod, + sp.from_dict( + { + "_target_": "tests.test_nicknames.nf", + "text": "bar", + } + ), + ) + + obj = sp.init.now(mod, str) + self.assertEqual(obj, "bar") + + obj = sp.init.now(mod, str, text="foo") + self.assertEqual(obj, "foo") + + def test_nickname_from_resolver(self): + mod = sp.from_string("func: ${sp.ref:function_nickname}") + mod = sp.resolve(mod) + self.assertEqual( + mod, + sp.from_python( + { + "func": { + "_target_": "tests.test_nicknames.nf", + "text": "bar", + } + } + ), + ) + obj = sp.init.now(mod.func, str, text="foo") + self.assertEqual(obj, "foo")