Skip to content

Commit

Permalink
misc improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
soldni committed Sep 2, 2022
1 parent 03bfc66 commit 824fda4
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 12 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "springs"
version = "1.3.2.2"
version = "1.3.3"
description = "A set of utilities to create and manage configuration files effectively, built on top of OmegaConf."
authors = [
{name = "Luca Soldaini", email = "[email protected]" }
Expand Down
38 changes: 27 additions & 11 deletions springs/flexyclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
DT = TypeVar("DT")


class FlexyClass:
...
FLEXY_FLAG = "__flexyclass__"


@dataclass_transform()
Expand All @@ -25,20 +24,34 @@ def make_flexy(cls_: Type[DT]) -> Type[DT]:
if not inspect.isclass(cls_) or not is_dataclass(cls_):
raise TypeError(f"flexyclass must decorate a dataclass, not {cls_}")

# type ignore is for pylance, which freaks out a bit otherwise
new_cls: Type[DT] = type(
f"FlexyClass{cls_.__name__}", (cls_, FlexyClass), {} # type: ignore
)
setattr(cls_, FLEXY_FLAG, FLEXY_FLAG)

return cls_ # pyright: ignore


def is_flexy(obj_or_cls: Any) -> bool:
"""Returns true if a class or configuration is flexy, false otherwise"""

return new_cls
if isinstance(obj_or_cls, (DictConfig, ListConfig)):
# case 1: we got a omega dict/list config; in this case, we use
# the get_type function to peek inside it.
obj_or_cls = get_type(obj_or_cls)

if not inspect.isclass(obj_or_cls):
# case 2: we got a generic instance of a dataclass; then we get
# the type directly
obj_or_cls = type(obj_or_cls)

# finally, we look for the flag
return hasattr(obj_or_cls, FLEXY_FLAG)


@dataclass_transform()
def flexyclass(cls: Type[DT]) -> Type[DT]: # type: ignore
def flexyclass(cls: Type[DT]) -> Type[DT]:
"""A flexyclass is like a dataclass, but it supports partial
specification of properties."""
SpringsWarnings.flexyclass()
return make_flexy(dataclass(cls)) # type: ignore
return make_flexy(dataclass(cls))


def flexy_field(type_: Type[DT], /, **kwargs: Any) -> DT:
Expand All @@ -51,7 +64,7 @@ def flexy_field(type_: Type[DT], /, **kwargs: Any) -> DT:
"""
SpringsWarnings.flexyfield()

if not issubclass(type_, FlexyClass) and not is_dataclass(type_):
if not is_dataclass(type_) and not is_flexy(type_):
raise TypeError(f"flexy_field must receive a flexyclass, not {type_}")

# find the argument that are extra from what has been defined in
Expand Down Expand Up @@ -89,6 +102,9 @@ def flexy_field(type_: Type[DT], /, **kwargs: Any) -> DT:
def unlock_all_flexyclasses(
cast_fn: Callable[..., DictOrListConfig]
) -> Callable[..., DictOrListConfig]:
"""Unlock all flexy classes in a configuration so they can be
merged without causing disruption."""

@wraps(cast_fn)
def unlock_fn(*args, **kwargs):
config_node = cast_fn(*args, **kwargs)
Expand All @@ -102,7 +118,7 @@ def unlock_fn(*args, **kwargs):
continue

typ_ = get_type(spec.value)
if typ_ and issubclass(typ_, FlexyClass):
if typ_ and is_flexy(typ_):
OmegaConf.set_struct(spec.value, False)

return config_node
Expand Down

0 comments on commit 824fda4

Please sign in to comment.