Skip to content

Commit

Permalink
can nickname any func now
Browse files Browse the repository at this point in the history
  • Loading branch information
soldni committed Feb 28, 2023
1 parent edf9925 commit bcd6cfe
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 56 deletions.
17 changes: 7 additions & 10 deletions src/springs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,15 @@
from functools import reduce
from inspect import isclass
from pathlib import Path
from typing import (
Any, Dict, List, Sequence, Tuple, TypeVar, Union, Callable,
cast as typing_cast, overload
)
from typing import Any, Callable, Dict, List, Sequence, Tuple, TypeVar, Union
from typing import cast as typing_cast
from typing import overload

from omegaconf import DictConfig, ListConfig, OmegaConf
from omegaconf.errors import MissingMandatoryValue
from omegaconf.omegaconf import DictKeyType
from typing_extensions import Concatenate, ParamSpec
from yaml.scanner import ScannerError
from typing_extensions import ParamSpec, Concatenate

from .flexyclasses import FlexyClass
from .traversal import FailedParamSpec, traverse
Expand All @@ -25,7 +24,7 @@
T = TypeVar("T")
R = TypeVar("R")
C = TypeVar("C", bound=Union[DictConfig, ListConfig])
PS = ParamSpec('PS')
PS = ParamSpec("PS")


def cast(config: Any, copy: bool = False) -> DictConfig:
Expand Down Expand Up @@ -59,9 +58,7 @@ def _from_allow_none_or_skip(
or returns the input if it is already a DictConfig or ListConfig"""

def wrapped(
config: Union[T, None],
*args: PS.args,
**kwargs: PS.kwargs
config: Union[T, None], *args: PS.args, **kwargs: PS.kwargs
) -> R:
if config is None:
return from_none()
Expand Down Expand Up @@ -101,7 +98,7 @@ def from_python(config: List[Any]) -> ListConfig:
...


@_from_allow_none_or_skip # type: ignore
@_from_allow_none_or_skip # type: ignore
def from_python(
config: Union[Dict[DictKeyType, Any], Dict[str, Any], List[Any]]
) -> Union[DictConfig, ListConfig]:
Expand Down
21 changes: 15 additions & 6 deletions src/springs/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,8 @@ def _find_child_type(
initializer, this function figures out the expected type using
a combination of class annotations and __init__ annotations.
If the type the attribute should be cannot be determined, it
simply returns None.
If the type the attribute cannot be determined, it simply returns
None.
An example:
Expand Down Expand Up @@ -418,12 +418,21 @@ def __init__(self, a):
return None

if attr_name in (parent_anns := get_annotations(cls_)):
# this is good for cases where the attribute is annotated
# in the class definition
return parent_anns[attr_name]

if attr_name in (
parent_anns := inspect.getfullargspec(cls_).annotations
):
return parent_anns[attr_name]
try:
# this fails on C types; we give up on nested initialization
# on those.
spec = inspect.getfullargspec(cls_)
except TypeError:
return None

if attr_name in spec.annotations:
# this is good for cases where the attribute is annotated
# in a function definition
return spec.annotations[attr_name]

# this is the case where we cannot resolve anything
return None
Expand Down
84 changes: 54 additions & 30 deletions src/springs/nicknames.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import inspect
from dataclasses import is_dataclass
from inspect import isclass
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
Literal,
Expand All @@ -13,23 +13,23 @@
Type,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import ParamSpec

from omegaconf import DictConfig, ListConfig
from omegaconf import MISSING, DictConfig, ListConfig
from typing_extensions import ParamSpec

from .core import from_file
from .core import from_dict, from_file
from .flexyclasses import FlexyClass
from .logging import configure_logging

RegistryValue = Union[Type[Any], Type[FlexyClass], DictConfig, ListConfig]
RegistryValue = Union[Callable, Type[FlexyClass], DictConfig, ListConfig]
# M = TypeVar("M", bound=RegistryValue)

T = TypeVar("T")
M = TypeVar("M", bound=RegistryValue)
P = ParamSpec("P")


LOGGER = configure_logging(__name__)


Expand Down Expand Up @@ -87,38 +87,62 @@ def scan(
LOGGER.warning(f"Could not load config from {path}")

@classmethod
def _add(cls, name: str, config: M) -> M:
def _add(cls, name: str, config: RegistryValue) -> RegistryValue:
cls.__registry__[name] = config
return config

# @overload
# @classmethod
# def add(cls, name: str) -> Callable[[Type[T]], Type[T]]:
# ...
@overload
@classmethod
def add(cls, name: str) -> Callable[[Type[T]], Type[T]]:
...

# @overload
# @classmethod
# def add(cls, name: str) -> Callable[[Callable[P, T]], Callable[P, T]]:
# ...
@overload
@classmethod
def add( # type: ignore
cls, name: str
) -> Callable[[Callable[P, T]], Callable[P, T]]:
...

@classmethod
def add(cls, name: str) -> Callable[[Callable[P, T]], Callable[P, T]]:
def add(
cls, name: str
) -> Union[
Callable[[Type[T]], Type[T]],
Callable[[Callable[P, T]], Callable[P, T]],
]:
"""Decorator to save a structured configuration with a nickname
for easy reuse."""

def add_to_registry(cls_: Callable[P, T]) -> Callable[P, T]:
if is_dataclass(cls_):
pass
elif isclass(cls_) and issubclass(cls_, FlexyClass):
pass
if name in cls.__registry__:
raise ValueError(f"{name} is already registered")

def add_to_registry(
cls_or_fn: Union[Type[T], Callable[P, T]]
) -> Union[Type[T], Callable[P, T]]:
if is_dataclass(cls_or_fn):
# Pylance complains about dataclasses not being a valid type,
# but the problem is DataclassInstance is only defined within
# Pylance, so I can't type annotate with that.
cls._add(name, cls_or_fn) # pyright: ignore
elif isclass(cls_or_fn) and issubclass(cls_or_fn, FlexyClass):
cls._add(name, cls_or_fn)
else:
raise ValueError(f"{cls_} must be a dataclass")

if name in cls.__registry__:
raise ValueError(f"{name} is already registered")
return cls._add(name, cls_)

return add_to_registry
from .initialize import Target, init

sig = inspect.signature(cls_or_fn)
entry = from_dict(
{
init.TARGET: Target.to_string(cls_or_fn),
**{
k: (v.default if v.default != v.empty else MISSING)
for k, v in sig.parameters.items()
},
}
)
cls._add(name, entry)
return cls_or_fn # type: ignore

return add_to_registry # type: ignore

@overload
@classmethod
Expand Down Expand Up @@ -147,7 +171,7 @@ def all(cls) -> Sequence[Tuple[str, str]]:
return [
(
name,
getattr(config, '__name__', type(config).__name__)
getattr(config, "__name__", type(config).__name__)
if is_dataclass(config)
else type(config).__name__,
)
Expand Down
8 changes: 5 additions & 3 deletions src/springs/shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
Optional,
Sequence,
Set,
Type,
TypeVar,
Union,
)

from typing_extensions import ParamSpec

from .field_utils import field
from .flexyclasses import flexyclass
from .initialize import Target
Expand All @@ -22,6 +23,7 @@
from .utils import SpringsConfig, SpringsWarnings

T = TypeVar("T")
P = ParamSpec("P")


def get_nickname(
Expand All @@ -45,9 +47,9 @@ def make_target(c: Callable) -> str:
return Target.to_string(c)


def nickname(name: str) -> Callable[[Type[T]], Type[T]]:
def nickname(name: str) -> Callable[[T], T]:
"""Shortcut for springs.nicknames.NicknameRegistry.add"""
return NicknameRegistry.add(name)
return NicknameRegistry.add(name) # type: ignore


def scan(
Expand Down
15 changes: 8 additions & 7 deletions tests/test_create.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from tempfile import NamedTemporaryFile
import unittest
from tempfile import NamedTemporaryFile

from omegaconf import DictConfig, ListConfig

import springs as sp


@sp.dataclass
class DT:
foo: str = 'bar'
foo: str = "bar"


class TestCreation(unittest.TestCase):
Expand All @@ -19,7 +20,7 @@ def test_from_dict(self):
self.assertEqual(sp.to_dict(sp.from_dict(DictConfig({}))), {})

def test_from_dataclass(self):
self.assertEqual(sp.to_dict(sp.from_dataclass(DT)), {'foo': 'bar'})
self.assertEqual(sp.to_dict(sp.from_dataclass(DT)), {"foo": "bar"})
self.assertEqual(sp.to_dict(sp.from_dataclass(None)), {})
self.assertEqual(sp.to_dict(sp.from_dataclass(DictConfig({}))), {})

Expand All @@ -31,11 +32,11 @@ def test_from_python(self):
sp.to_python(sp.from_python(None)), {} # type: ignore
)
self.assertEqual(
sp.to_python(sp.from_python(ListConfig([]))), [] # type: ignore
sp.to_python(sp.from_python(ListConfig([]))), [] # type: ignore
)

def test_from_file(self):
with NamedTemporaryFile('w') as f:
f.write('foo: bar')
with NamedTemporaryFile("w") as f:
f.write("foo: bar")
f.flush()
self.assertEqual(sp.to_dict(sp.from_file(f.name)), {'foo': 'bar'})
self.assertEqual(sp.to_dict(sp.from_file(f.name)), {"foo": "bar"})
56 changes: 56 additions & 0 deletions tests/test_nicknames.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,17 @@ class DevConfig:
batch_size: int = 32


@sp.nickname("class_nickname")
class NC:
def __init__(self) -> None:
self.foo = "bar"


@sp.nickname("function_nickname")
def nf(text: str = "bar"):
return text


class TestNicknames(unittest.TestCase):
def setUp(self) -> None:
self.cfg = sp.from_dict(
Expand Down Expand Up @@ -64,3 +75,48 @@ def test_dict_nicknames(self):
mod2 = sp.get_nickname("test/temp/config")
self.assertTrue(isinstance(mod2, DictConfig))
self.assertEqual(sp.to_python(mod), sp.to_python(mod2))

def test_class_nickname(self):
mod = sp.get_nickname("class_nickname")
self.assertEqual(
mod, sp.from_dict({"_target_": "tests.test_nicknames.NC"})
)

obj = sp.init.now(mod, NC)
self.assertTrue(isinstance(obj, NC))
self.assertEqual(obj.foo, "bar")

def test_function_nickname(self):
mod = sp.get_nickname("function_nickname")
self.assertEqual(
mod,
sp.from_dict(
{
"_target_": "tests.test_nicknames.nf",
"text": "bar",
}
),
)

obj = sp.init.now(mod, str)
self.assertEqual(obj, "bar")

obj = sp.init.now(mod, str, text="foo")
self.assertEqual(obj, "foo")

def test_nickname_from_resolver(self):
mod = sp.from_string("func: ${sp.ref:function_nickname}")
mod = sp.resolve(mod)
self.assertEqual(
mod,
sp.from_python(
{
"func": {
"_target_": "tests.test_nicknames.nf",
"text": "bar",
}
}
),
)
obj = sp.init.now(mod.func, str, text="foo")
self.assertEqual(obj, "foo")

0 comments on commit bcd6cfe

Please sign in to comment.