diff --git a/pyproject.toml b/pyproject.toml index fbba8c8..e901504 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "springs" -version = "1.3.2.2" +version = "1.3.3" description = "A set of utilities to create and manage configuration files effectively, built on top of OmegaConf." authors = [ {name = "Luca Soldaini", email = "luca@soldaini.net" } diff --git a/springs/flexyclasses.py b/springs/flexyclasses.py index cf2cc76..3323319 100644 --- a/springs/flexyclasses.py +++ b/springs/flexyclasses.py @@ -13,8 +13,7 @@ DT = TypeVar("DT") -class FlexyClass: - ... +FLEXY_FLAG = "__flexyclass__" @dataclass_transform() @@ -25,20 +24,34 @@ def make_flexy(cls_: Type[DT]) -> Type[DT]: if not inspect.isclass(cls_) or not is_dataclass(cls_): raise TypeError(f"flexyclass must decorate a dataclass, not {cls_}") - # type ignore is for pylance, which freaks out a bit otherwise - new_cls: Type[DT] = type( - f"FlexyClass{cls_.__name__}", (cls_, FlexyClass), {} # type: ignore - ) + setattr(cls_, FLEXY_FLAG, FLEXY_FLAG) + + return cls_ # pyright: ignore + + +def is_flexy(obj_or_cls: Any) -> bool: + """Returns true if a class or configuration is flexy, false otherwise""" - return new_cls + if isinstance(obj_or_cls, (DictConfig, ListConfig)): + # case 1: we got a omega dict/list config; in this case, we use + # the get_type function to peek inside it. + obj_or_cls = get_type(obj_or_cls) + + if not inspect.isclass(obj_or_cls): + # case 2: we got a generic instance of a dataclass; then we get + # the type directly + obj_or_cls = type(obj_or_cls) + + # finally, we look for the flag + return hasattr(obj_or_cls, FLEXY_FLAG) @dataclass_transform() -def flexyclass(cls: Type[DT]) -> Type[DT]: # type: ignore +def flexyclass(cls: Type[DT]) -> Type[DT]: """A flexyclass is like a dataclass, but it supports partial specification of properties.""" SpringsWarnings.flexyclass() - return make_flexy(dataclass(cls)) # type: ignore + return make_flexy(dataclass(cls)) def flexy_field(type_: Type[DT], /, **kwargs: Any) -> DT: @@ -51,7 +64,7 @@ def flexy_field(type_: Type[DT], /, **kwargs: Any) -> DT: """ SpringsWarnings.flexyfield() - if not issubclass(type_, FlexyClass) and not is_dataclass(type_): + if not is_dataclass(type_) and not is_flexy(type_): raise TypeError(f"flexy_field must receive a flexyclass, not {type_}") # find the argument that are extra from what has been defined in @@ -89,6 +102,9 @@ def flexy_field(type_: Type[DT], /, **kwargs: Any) -> DT: def unlock_all_flexyclasses( cast_fn: Callable[..., DictOrListConfig] ) -> Callable[..., DictOrListConfig]: + """Unlock all flexy classes in a configuration so they can be + merged without causing disruption.""" + @wraps(cast_fn) def unlock_fn(*args, **kwargs): config_node = cast_fn(*args, **kwargs) @@ -102,7 +118,7 @@ def unlock_fn(*args, **kwargs): continue typ_ = get_type(spec.value) - if typ_ and issubclass(typ_, FlexyClass): + if typ_ and is_flexy(typ_): OmegaConf.set_struct(spec.value, False) return config_node