Skip to content

Commit

Permalink
Merge pull request #13 from soldni/soldni/nicknames
Browse files Browse the repository at this point in the history
Support Nickname for any function or class
  • Loading branch information
soldni authored Feb 28, 2023
2 parents 9c45fc5 + bcd6cfe commit 0a0ca46
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 35 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "springs"
version = "1.11.1"
version = "1.12"
description = """\
A set of utilities to create and manage typed configuration files \
effectively, built on top of OmegaConf.\
Expand Down
34 changes: 30 additions & 4 deletions src/springs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from functools import reduce
from inspect import isclass
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, List, Sequence, Tuple, TypeVar, Union
from typing import cast as typing_cast
from typing import overload

from omegaconf import DictConfig, ListConfig, OmegaConf
from omegaconf.errors import MissingMandatoryValue
from omegaconf.omegaconf import DictKeyType
from typing_extensions import Concatenate, ParamSpec
from yaml.scanner import ScannerError

from .flexyclasses import FlexyClass
Expand All @@ -20,7 +21,10 @@
DEFAULT: Any = "***"


T = TypeVar("T")
R = TypeVar("R")
C = TypeVar("C", bound=Union[DictConfig, ListConfig])
PS = ParamSpec("PS")


def cast(config: Any, copy: bool = False) -> DictConfig:
Expand All @@ -47,11 +51,28 @@ def from_none(*args: Any, **kwargs: Any) -> DictConfig:
return OmegaConf.create()


def _from_allow_none_or_skip(
fn: Callable[Concatenate[T, PS], R]
) -> Callable[Concatenate[Union[T, DictConfig, ListConfig, None], PS], R]:
"""Decorator that creates an empty config if the input is None,
or returns the input if it is already a DictConfig or ListConfig"""

def wrapped(
config: Union[T, None], *args: PS.args, **kwargs: PS.kwargs
) -> R:
if config is None:
return from_none()
if isinstance(config, (DictConfig, ListConfig)):
return config

return fn(config, *args, **kwargs)

return wrapped


@_from_allow_none_or_skip
def from_dataclass(config: Any) -> DictConfig:
"""Cast a dataclass to a structured omega config"""
if config is None:
return from_none()

if isclass(config) and issubclass(config, FlexyClass):
config = config.defaults()

Expand All @@ -77,6 +98,7 @@ def from_python(config: List[Any]) -> ListConfig:
...


@_from_allow_none_or_skip # type: ignore
def from_python(
config: Union[Dict[DictKeyType, Any], Dict[str, Any], List[Any]]
) -> Union[DictConfig, ListConfig]:
Expand All @@ -94,6 +116,7 @@ def from_python(
return parsed_config


@_from_allow_none_or_skip
def from_dict(
config: Union[Dict[DictKeyType, Any], Dict[str, Any]]
) -> DictConfig:
Expand All @@ -104,6 +127,7 @@ def from_dict(
return from_python(config) # type: ignore


@_from_allow_none_or_skip
def from_string(config: str) -> DictConfig:
"""Load a config from a string"""
if not isinstance(config, str):
Expand All @@ -116,6 +140,7 @@ def from_string(config: str) -> DictConfig:
return parsed_config


@_from_allow_none_or_skip
def from_file(path: Union[str, Path]) -> DictConfig:
"""Load a config from a file, either YAML or JSON"""
path = Path(path)
Expand All @@ -142,6 +167,7 @@ def from_file(path: Union[str, Path]) -> DictConfig:
return config


@_from_allow_none_or_skip
def from_options(opts: Sequence[str]) -> DictConfig:
"""Create a config from a list of options"""
if not isinstance(opts, abc.Sequence) or not all(
Expand Down
21 changes: 15 additions & 6 deletions src/springs/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,8 @@ def _find_child_type(
initializer, this function figures out the expected type using
a combination of class annotations and __init__ annotations.
If the type the attribute should be cannot be determined, it
simply returns None.
If the type the attribute cannot be determined, it simply returns
None.
An example:
Expand Down Expand Up @@ -418,12 +418,21 @@ def __init__(self, a):
return None

if attr_name in (parent_anns := get_annotations(cls_)):
# this is good for cases where the attribute is annotated
# in the class definition
return parent_anns[attr_name]

if attr_name in (
parent_anns := inspect.getfullargspec(cls_).annotations
):
return parent_anns[attr_name]
try:
# this fails on C types; we give up on nested initialization
# on those.
spec = inspect.getfullargspec(cls_)
except TypeError:
return None

if attr_name in spec.annotations:
# this is good for cases where the attribute is annotated
# in a function definition
return spec.annotations[attr_name]

# this is the case where we cannot resolve anything
return None
Expand Down
78 changes: 57 additions & 21 deletions src/springs/nicknames.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import inspect
from dataclasses import is_dataclass
from inspect import isclass
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Literal,
Expand All @@ -13,20 +13,22 @@
Type,
TypeVar,
Union,
cast,
overload,
)

from omegaconf import DictConfig, ListConfig
from omegaconf import MISSING, DictConfig, ListConfig
from typing_extensions import ParamSpec

from .core import from_file
from .core import from_dict, from_file
from .flexyclasses import FlexyClass
from .logging import configure_logging

RegistryValue = Union[Type[Any], Type[FlexyClass], DictConfig, ListConfig]
RegistryValue = Union[Callable, Type[FlexyClass], DictConfig, ListConfig]
# M = TypeVar("M", bound=RegistryValue)

T = TypeVar("T")
M = TypeVar("M", bound=RegistryValue)
P = ParamSpec("P")


LOGGER = configure_logging(__name__)

Expand Down Expand Up @@ -85,28 +87,62 @@ def scan(
LOGGER.warning(f"Could not load config from {path}")

@classmethod
def _add(cls, name: str, config: M) -> M:
def _add(cls, name: str, config: RegistryValue) -> RegistryValue:
cls.__registry__[name] = config
return config

@overload
@classmethod
def add(cls, name: str) -> Callable[[Type[T]], Type[T]]:
"""Decorator to save a structured configuration with a nickname
for easy reuse."""
...

def add_to_registry(cls_: Type[T]) -> Type[T]:
if not (
is_dataclass(cls_)
or isclass(cls_)
and issubclass(cls_, FlexyClass)
):
raise ValueError(f"{cls_} must be a dataclass")
@overload
@classmethod
def add( # type: ignore
cls, name: str
) -> Callable[[Callable[P, T]], Callable[P, T]]:
...

if name in cls.__registry__:
raise ValueError(f"{name} is already registered")
return cast(Type[T], cls._add(name, cls_))
@classmethod
def add(
cls, name: str
) -> Union[
Callable[[Type[T]], Type[T]],
Callable[[Callable[P, T]], Callable[P, T]],
]:
"""Decorator to save a structured configuration with a nickname
for easy reuse."""

return add_to_registry
if name in cls.__registry__:
raise ValueError(f"{name} is already registered")

def add_to_registry(
cls_or_fn: Union[Type[T], Callable[P, T]]
) -> Union[Type[T], Callable[P, T]]:
if is_dataclass(cls_or_fn):
# Pylance complains about dataclasses not being a valid type,
# but the problem is DataclassInstance is only defined within
# Pylance, so I can't type annotate with that.
cls._add(name, cls_or_fn) # pyright: ignore
elif isclass(cls_or_fn) and issubclass(cls_or_fn, FlexyClass):
cls._add(name, cls_or_fn)
else:
from .initialize import Target, init

sig = inspect.signature(cls_or_fn)
entry = from_dict(
{
init.TARGET: Target.to_string(cls_or_fn),
**{
k: (v.default if v.default != v.empty else MISSING)
for k, v in sig.parameters.items()
},
}
)
cls._add(name, entry)
return cls_or_fn # type: ignore

return add_to_registry # type: ignore

@overload
@classmethod
Expand Down Expand Up @@ -135,7 +171,7 @@ def all(cls) -> Sequence[Tuple[str, str]]:
return [
(
name,
str(config.__name__)
getattr(config, "__name__", type(config).__name__)
if is_dataclass(config)
else type(config).__name__,
)
Expand Down
6 changes: 6 additions & 0 deletions src/springs/rich_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,12 @@ def add_pretty_traceback(**install_kwargs: Any) -> None:


class RichArgumentParser(ArgumentParser):
theme: SpringsTheme
entrypoint: Optional[str]
arguments: Optional[str]
formatted: Dict[str, Any]
console_kwargs: Dict[str, Any]

def __init__(
self,
*args,
Expand Down
8 changes: 5 additions & 3 deletions src/springs/shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
Optional,
Sequence,
Set,
Type,
TypeVar,
Union,
)

from typing_extensions import ParamSpec

from .field_utils import field
from .flexyclasses import flexyclass
from .initialize import Target
Expand All @@ -22,6 +23,7 @@
from .utils import SpringsConfig, SpringsWarnings

T = TypeVar("T")
P = ParamSpec("P")


def get_nickname(
Expand All @@ -45,9 +47,9 @@ def make_target(c: Callable) -> str:
return Target.to_string(c)


def nickname(name: str) -> Callable[[Type[T]], Type[T]]:
def nickname(name: str) -> Callable[[T], T]:
"""Shortcut for springs.nicknames.NicknameRegistry.add"""
return NicknameRegistry.add(name)
return NicknameRegistry.add(name) # type: ignore


def scan(
Expand Down
42 changes: 42 additions & 0 deletions tests/test_create.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import unittest
from tempfile import NamedTemporaryFile

from omegaconf import DictConfig, ListConfig

import springs as sp


@sp.dataclass
class DT:
foo: str = "bar"


class TestCreation(unittest.TestCase):
def test_from_dict(self):
self.assertEqual(
sp.to_dict(sp.from_dict({"foo": "bar"})), {"foo": "bar"}
)
self.assertEqual(sp.to_dict(sp.from_dict(None)), {})
self.assertEqual(sp.to_dict(sp.from_dict(DictConfig({}))), {})

def test_from_dataclass(self):
self.assertEqual(sp.to_dict(sp.from_dataclass(DT)), {"foo": "bar"})
self.assertEqual(sp.to_dict(sp.from_dataclass(None)), {})
self.assertEqual(sp.to_dict(sp.from_dataclass(DictConfig({}))), {})

def test_from_python(self):
self.assertEqual(
sp.to_python(sp.from_python({"foo": "bar"})), {"foo": "bar"}
)
self.assertEqual(
sp.to_python(sp.from_python(None)), {} # type: ignore
)
self.assertEqual(
sp.to_python(sp.from_python(ListConfig([]))), [] # type: ignore
)

def test_from_file(self):
with NamedTemporaryFile("w") as f:
f.write("foo: bar")
f.flush()
self.assertEqual(sp.to_dict(sp.from_file(f.name)), {"foo": "bar"})
Loading

0 comments on commit 0a0ca46

Please sign in to comment.