From 4ba294651a190a542aefd83d6fd70d3350b21bc7 Mon Sep 17 00:00:00 2001 From: bnewm0609 Date: Mon, 31 Jul 2023 16:11:37 -0700 Subject: [PATCH 1/6] Make dataclasses hashable --- src/springs/commandline.py | 3 +++ tests/fixtures/full_config/config.py | 11 ++++++++--- tests/test_nested_init.py | 4 ++-- tests/test_nicknames.py | 2 +- tests/test_resolvers.py | 2 +- tests/test_resolvers_in_nested.py | 2 +- 6 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/springs/commandline.py b/src/springs/commandline.py index 1bdf6d6..f004752 100644 --- a/src/springs/commandline.py +++ b/src/springs/commandline.py @@ -91,6 +91,9 @@ def add_argparse(self, parser: RichArgumentParser) -> Action: def __str__(self) -> str: return f"{self.short}/{self.long}" + + def __hash__(self) -> int: + return hash(str(self)) @dataclass diff --git a/tests/fixtures/full_config/config.py b/tests/fixtures/full_config/config.py index 337bed4..506af72 100644 --- a/tests/fixtures/full_config/config.py +++ b/tests/fixtures/full_config/config.py @@ -50,7 +50,7 @@ class DataConfig: test_splits_config: List[DataSplitConfig] = sp.field(default_factory=list) -@sp.dataclass +@sp.dataclass(unsafe_hash=True) class EnvironmentConfig: root_dir: Optional[str] = "~/plruns" run_name: Optional[str] = "sse" @@ -58,7 +58,7 @@ class EnvironmentConfig: seed: int = 5663 -@sp.dataclass +@sp.dataclass(unsafe_hash=True) class ModelConfig: _target_: str = "sse.models.TokenClassificationModule" tokenizer: HuggingFaceModuleConfig = HuggingFaceModuleConfig( @@ -98,7 +98,7 @@ class TextLoggerConfig: version: str = "" -@sp.dataclass +@sp.dataclass(unsafe_hash=True) class LoggersConfig: graphic: GraphicLoggerConfig = GraphicLoggerConfig() text: TextLoggerConfig = TextLoggerConfig() @@ -140,6 +140,11 @@ class SseConfig: checkpoint: Optional[str] = None # this controls training environment and data + # env: EnvironmentConfig = sp.fobj(EnvironmentConfig()) + # data: DataConfig = sp.fobj(DataConfig(), help="Data configuration") + # model: ModelConfig = sp.fobj(ModelConfig()) + # loggers: LoggersConfig = sp.fobj(LoggersConfig()) + # trainer: TrainerConfig = sp.fobj(TrainerConfig()) env: EnvironmentConfig = EnvironmentConfig() data: DataConfig = sp.fobj(DataConfig(), help="Data configuration") model: ModelConfig = ModelConfig() diff --git a/tests/test_nested_init.py b/tests/test_nested_init.py index 86e1e52..6230261 100644 --- a/tests/test_nested_init.py +++ b/tests/test_nested_init.py @@ -15,13 +15,13 @@ def __init__(self, a: Inner, b: int) -> None: self.b = b -@dataclass +@dataclass(unsafe_hash=True) class InnerConfig: _target_: str = Target.to_string(Inner) a: int = 1 -@dataclass +@dataclass(unsafe_hash=True) class OuterConfig: _target_: str = Target.to_string(Outer) a: InnerConfig = InnerConfig() diff --git a/tests/test_nicknames.py b/tests/test_nicknames.py index 8bc1556..2d5027b 100644 --- a/tests/test_nicknames.py +++ b/tests/test_nicknames.py @@ -9,7 +9,7 @@ import springs as sp -@dataclass +@dataclass(unsafe_hash=True) class DataConfig: path: str = sp.MISSING diff --git a/tests/test_resolvers.py b/tests/test_resolvers.py index a46b73d..28c0aa1 100644 --- a/tests/test_resolvers.py +++ b/tests/test_resolvers.py @@ -20,4 +20,4 @@ def test_sanitize(self): self.assertEqual(c.c, "___fooo___") self.assertEqual(c.d, c.b) self.assertEqual(c.e, "_f") - self.assertEqual(c.f, " fooo") + self.assertEqual(c.f, "fooo") diff --git a/tests/test_resolvers_in_nested.py b/tests/test_resolvers_in_nested.py index 27c3547..504c5f8 100644 --- a/tests/test_resolvers_in_nested.py +++ b/tests/test_resolvers_in_nested.py @@ -7,7 +7,7 @@ import springs as sp -@sp.dataclass +@sp.dataclass(unsafe_hash=True) class NestedConfig: value: int = 0 From 729a44dce750fca9f4087a14fee5d532c5a28caf Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Tue, 1 Aug 2023 16:27:06 -0700 Subject: [PATCH 2/6] style --- src/springs/commandline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/springs/commandline.py b/src/springs/commandline.py index f004752..d51e4bd 100644 --- a/src/springs/commandline.py +++ b/src/springs/commandline.py @@ -91,7 +91,7 @@ def add_argparse(self, parser: RichArgumentParser) -> Action: def __str__(self) -> str: return f"{self.short}/{self.long}" - + def __hash__(self) -> int: return hash(str(self)) From 0360ef68f50d7b640d301b9f07e9e53f8cf4eaf7 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Tue, 1 Aug 2023 17:27:16 -0700 Subject: [PATCH 3/6] type ignore --- src/springs/commandline.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/springs/commandline.py b/src/springs/commandline.py index d51e4bd..4b4286f 100644 --- a/src/springs/commandline.py +++ b/src/springs/commandline.py @@ -1,3 +1,4 @@ +import functools import re import sys from argparse import Action @@ -42,6 +43,7 @@ # parameters for the main function MP = ParamSpec("MP") +NP = ParamSpec("NP") # type for the configuration CT = TypeVar("CT") @@ -433,10 +435,8 @@ def wrap_main_method( def cli( config_node_cls: Optional[Type[CT]] = None, ) -> Callable[ - [ - # this is a main method that takes as first input a parsed config - Callable[Concatenate[CT, MP], RT] - ], + # this is a main method that takes as first input a parsed config + [Callable[Concatenate[CT, MP], RT]], # the decorated method doesn't expect the parsed config as first input, # since that will be parsed from the command line Callable[MP, RT], @@ -490,6 +490,7 @@ def main(cfg: Config): name = config_node_cls.__name__ def wrapper(func: Callable[Concatenate[CT, MP], RT]) -> Callable[MP, RT]: + @functools.wraps(func) def wrapping(*args: MP.args, **kwargs: MP.kwargs) -> RT: # I could have used a functools.partial here, but defining # my own function instead allows me to provide nice typing @@ -504,4 +505,8 @@ def wrapping(*args: MP.args, **kwargs: MP.kwargs) -> RT: return wrapping - return wrapper + # TODO: figure out why mypy complains with the following error: + # Incompatible return value type (got "Callable[[Arg(Callable[[CT, + # **MP], RT], 'func')], Callable[MP, RT]]", expected + # "Callable[[Callable[[CT, **MP], RT]], Callable[MP, RT]]") + return wrapper # type: ignore From 12db9662ad805681cf1765a00d3e3e05e607a21a Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Tue, 1 Aug 2023 23:30:34 -0700 Subject: [PATCH 4/6] tests --- src/springs/__init__.py | 5 --- src/springs/commandline.py | 26 ++++++----- src/springs/flexyclasses.py | 22 +++++----- src/springs/rich_utils.py | 14 +++++- src/springs/shortcuts.py | 30 ++----------- tests/fixtures/full_config/config.py | 65 +++++++++++++++------------- tests/test_create.py | 3 ++ tests/test_flexyclass.py | 25 +++++------ tests/test_nested_init.py | 14 +++--- tests/test_new_merge.py | 8 ++-- tests/test_nicknames.py | 4 +- tests/test_resolvers_in_nested.py | 4 +- 12 files changed, 108 insertions(+), 112 deletions(-) diff --git a/src/springs/__init__.py b/src/springs/__init__.py index f4e3b6c..b3abfac 100644 --- a/src/springs/__init__.py +++ b/src/springs/__init__.py @@ -33,8 +33,6 @@ debug_logger, fdict, flist, - fobj, - fval, get_nickname, make_flexy, make_target, @@ -49,7 +47,6 @@ __version__ = get_version() __all__ = [ - "add_help", "all_resolvers", "cast", "cli", @@ -63,8 +60,6 @@ "field", "flexyclass", "flist", - "fobj", - "fval", "from_dataclass", "from_dict", "from_file", diff --git a/src/springs/commandline.py b/src/springs/commandline.py index 4b4286f..a1561ba 100644 --- a/src/springs/commandline.py +++ b/src/springs/commandline.py @@ -30,6 +30,7 @@ to_yaml, unsafe_merge, ) +from .field_utils import field from .flexyclasses import is_flexyclass from .logging import configure_logging from .nicknames import NicknameRegistry @@ -94,13 +95,14 @@ def add_argparse(self, parser: RichArgumentParser) -> Action: def __str__(self) -> str: return f"{self.short}/{self.long}" - def __hash__(self) -> int: - return hash(str(self)) + @classmethod + def field(cls, *args, **kwargs) -> "Flag": + return field(default_factory=lambda: cls(*args, **kwargs)) @dataclass class CliFlags: - config: Flag = Flag( + config: Flag = Flag.field( name="config", help=( "either a path to a YAML file containing a configuration, or " @@ -112,22 +114,22 @@ class CliFlags: action="append", metavar="/path/to/config.yaml", ) - options: Flag = Flag( + options: Flag = Flag.field( name="options", help="print all default options and CLI flags.", action="store_true", ) - inputs: Flag = Flag( + inputs: Flag = Flag.field( name="inputs", help="print the input configuration.", action="store_true", ) - parsed: Flag = Flag( + parsed: Flag = Flag.field( name="parsed", help="print the parsed configuration.", action="store_true", ) - log_level: Flag = Flag( + log_level: Flag = Flag.field( name="log-level", help=( "logging level to use for this program; can be one of " @@ -136,17 +138,17 @@ class CliFlags: default="WARNING", choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"], ) - debug: Flag = Flag( + debug: Flag = Flag.field( name="debug", help="enable debug mode; equivalent to '--log-level DEBUG'", action="store_true", ) - quiet: Flag = Flag( + quiet: Flag = Flag.field( name="quiet", help="if provided, it does not print the configuration when running", action="store_true", ) - resolvers: Flag = Flag( + resolvers: Flag = Flag.field( name="resolvers", help=( "print all registered resolvers in OmegaConf, " @@ -154,12 +156,12 @@ class CliFlags: ), action="store_true", ) - nicknames: Flag = Flag( + nicknames: Flag = Flag.field( name="nicknames", help="print all registered nicknames in Springs", action="store_true", ) - save: Flag = Flag( + save: Flag = Flag.field( name="save", help="save the configuration to a YAML file and exit", default=None, diff --git a/src/springs/flexyclasses.py b/src/springs/flexyclasses.py index 2969ef4..5fb8594 100644 --- a/src/springs/flexyclasses.py +++ b/src/springs/flexyclasses.py @@ -9,18 +9,18 @@ from .utils import get_annotations -C = TypeVar("C", bound=Any) +_C = TypeVar("_C", bound=Any) -class FlexyClass(dict, Generic[C]): +class FlexyClass(dict, Generic[_C]): """A FlexyClass is a dictionary with some default values assigned to it FlexyClasses are generally not used directly, but rather creating using the `flexyclass` decorator. NOTE: When instantiating a new FlexyClass object directly, the constructor - actually returns a `dataclasses.Field` object. This is for API consistency - with how dataclasses are used in a structured configuration. If you want to - access values in the FlexyClass directly, use FlexyClass.defaults property. + actually returns a `dict` object. This is for API consistency with how + dataclasses are used in a structured configuration. If you want to access + values in the FlexyClass directly, use FlexyClass.defaults property. """ __origin__: type = dict @@ -60,7 +60,8 @@ def __new__(cls, **kwargs): # to use flexyclasses in the same way they would use a dataclass. factory_dict: Dict[str, Any] = {} factory_dict = {**cls.defaults(), **kwargs} - return field(default_factory=lambda: factory_dict) + return factory_dict + # return field(default_factory=lambda: factory_dict) @classmethod def to_dict_config(cls, **kwargs: Any) -> DictConfig: @@ -70,7 +71,7 @@ def to_dict_config(cls, **kwargs: Any) -> DictConfig: return from_dict({**cls.defaults(), **kwargs}) @classmethod - def flexyclass(cls, target_cls: Type[C]) -> Type["FlexyClass"]: + def flexyclass(cls, target_cls: Type[_C]) -> Type["FlexyClass[_C]"]: """Decorator to create a FlexyClass from a class""" if is_dataclass(target_cls): @@ -86,15 +87,16 @@ def flexyclass(cls, target_cls: Type[C]) -> Type["FlexyClass"]: for f_name, f_value in attributes_iterator } - return type( + rt = type( target_cls.__name__, (FlexyClass,), {"__flexyclass_defaults__": defaults}, ) + return rt -@dataclass_transform() -def flexyclass(cls: Type[C]) -> Type[FlexyClass[C]]: +@dataclass_transform(field_specifiers=(Field, field)) +def flexyclass(cls: Type[_C]) -> Type[FlexyClass[_C]]: """Alias for FlexyClass.flexyclass""" return FlexyClass.flexyclass(cls) diff --git a/src/springs/rich_utils.py b/src/springs/rich_utils.py index 712b1da..70c1d7d 100644 --- a/src/springs/rich_utils.py +++ b/src/springs/rich_utils.py @@ -2,7 +2,17 @@ import re from argparse import SUPPRESS, ArgumentParser from dataclasses import dataclass -from typing import IO, Any, Dict, Generator, List, Optional, Sequence, Union +from typing import ( + IO, + Any, + Dict, + Generator, + List, + Optional, + Sequence, + Tuple, + Union, +) from omegaconf import DictConfig, ListConfig from rich import box @@ -153,7 +163,7 @@ def format_usage(self): for ag in self._action_groups: for act in ag._group_actions: if isinstance(act.metavar, str): - metavar = (act.metavar,) + metavar: Tuple[str, ...] = (act.metavar,) elif act.metavar is None: metavar = (act.dest.upper(),) else: diff --git a/src/springs/shortcuts.py b/src/springs/shortcuts.py index fd4b69c..a20a0e2 100644 --- a/src/springs/shortcuts.py +++ b/src/springs/shortcuts.py @@ -73,37 +73,14 @@ def make_flexy(cls_: Any) -> Any: return flexyclass(cls_) -def fval(value: T, **kwargs) -> T: - """Shortcut for creating a Field with a default value. - - Args: - value: value returned by default factory""" - - return field(default=value, **kwargs) - - -def fobj(object: T, **kwargs) -> T: - """Shortcut for creating a Field with a default_factory that returns - a specific object. - - Args: - obj: object returned by default factory""" - - def _factory_fn() -> T: - # make a copy so that the same object isn't returned - # (it's a factory, not a singleton!) - return copy.deepcopy(object) - - return field(default_factory=_factory_fn, **kwargs) - - def fdict(**kwargs: Any) -> Dict[str, Any]: """Shortcut for creating a Field with a default_factory that returns a dictionary. Args: **kwargs: values for the dictionary returned by default factory""" - return fobj(kwargs) + kwargs = copy.deepcopy(kwargs) + return field(default_factory=lambda: kwargs) def flist(*args: Any) -> List[Any]: @@ -112,7 +89,8 @@ def flist(*args: Any) -> List[Any]: Args: *args: values for the list returned by default factory""" - return fobj(list(args)) + l_args = list(copy.deepcopy(args)) + return field(default_factory=lambda: l_args) def debug_logger(*args: Any, **kwargs: Any) -> Logger: diff --git a/tests/fixtures/full_config/config.py b/tests/fixtures/full_config/config.py index 506af72..9fd063b 100644 --- a/tests/fixtures/full_config/config.py +++ b/tests/fixtures/full_config/config.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union import springs as sp @@ -6,6 +5,11 @@ PL = "pytorch_lightning" +@sp.dataclass +class SpringsConfig: + foo: int = 1 + + @sp.flexyclass class TargetConfig: _target_: str = sp.MISSING @@ -20,7 +24,6 @@ class LoaderConfig: @sp.flexyclass -@dataclass class HuggingFaceModuleConfig: _target_: str = sp.MISSING pretrained_model_name_or_path: str = "${backbone}" @@ -33,7 +36,7 @@ class MapperConfig: @sp.dataclass class DataSplitConfig: - loader: LoaderConfig = LoaderConfig() + loader: LoaderConfig = sp.field(default_factory=LoaderConfig) mappers: List[MapperConfig] = sp.field(default_factory=list) @@ -44,13 +47,13 @@ class DataConfig: num_workers: int = 0 pin_memory: bool = False persistent_workers: bool = False - collator: TargetConfig = TargetConfig() + collator: TargetConfig = sp.field(default_factory=TargetConfig) train_splits_config: List[DataSplitConfig] = sp.field(default_factory=list) valid_splits_config: List[DataSplitConfig] = sp.field(default_factory=list) test_splits_config: List[DataSplitConfig] = sp.field(default_factory=list) -@sp.dataclass(unsafe_hash=True) +@sp.dataclass class EnvironmentConfig: root_dir: Optional[str] = "~/plruns" run_name: Optional[str] = "sse" @@ -58,15 +61,21 @@ class EnvironmentConfig: seed: int = 5663 -@sp.dataclass(unsafe_hash=True) +@sp.dataclass class ModelConfig: _target_: str = "sse.models.TokenClassificationModule" - tokenizer: HuggingFaceModuleConfig = HuggingFaceModuleConfig( - _target_="transformers.AutoTokenizer.from_pretrained" + tokenizer: HuggingFaceModuleConfig = sp.field( + default_factory=lambda: HuggingFaceModuleConfig( + _target_="transformers.AutoTokenizer.from_pretrained" + ) ) - transformer: HuggingFaceModuleConfig = HuggingFaceModuleConfig( - _target_=( - "transformers.AutoModelForSequenceClassification.from_pretrained" + transformer: HuggingFaceModuleConfig = sp.field( + default_factory=lambda: HuggingFaceModuleConfig( + _target_=( + "transformers." + "AutoModelForSequenceClassification." + "from_pretrained" + ) ) ) val_loss_label: str = "val_loss" @@ -98,10 +107,12 @@ class TextLoggerConfig: version: str = "" -@sp.dataclass(unsafe_hash=True) +@sp.dataclass class LoggersConfig: - graphic: GraphicLoggerConfig = GraphicLoggerConfig() - text: TextLoggerConfig = TextLoggerConfig() + graphic: GraphicLoggerConfig = sp.field( + default_factory=GraphicLoggerConfig + ) + text: TextLoggerConfig = sp.field(default_factory=TextLoggerConfig) @sp.flexyclass @@ -134,23 +145,19 @@ class TrainerConfig: @sp.dataclass class SseConfig: # base strings to control where models and tokenizers come from - backbone: Optional[str] = sp.fobj( - None, help="name of the transformers model to use" + backbone: Optional[str] = sp.field( + default=None, help="name of the transformers model to use" ) checkpoint: Optional[str] = None - # this controls training environment and data - # env: EnvironmentConfig = sp.fobj(EnvironmentConfig()) - # data: DataConfig = sp.fobj(DataConfig(), help="Data configuration") - # model: ModelConfig = sp.fobj(ModelConfig()) - # loggers: LoggersConfig = sp.fobj(LoggersConfig()) - # trainer: TrainerConfig = sp.fobj(TrainerConfig()) - env: EnvironmentConfig = EnvironmentConfig() - data: DataConfig = sp.fobj(DataConfig(), help="Data configuration") - model: ModelConfig = ModelConfig() - loggers: LoggersConfig = LoggersConfig() - trainer: TrainerConfig = TrainerConfig() - checkpointing: Optional[CheckpointConfig] = sp.fval( - None, help="optional configurations to deal with checkpointing" + env: EnvironmentConfig = sp.field(default_factory=EnvironmentConfig) + data: DataConfig = sp.field( + default_factory=DataConfig, help="Data configuration" + ) + model: ModelConfig = sp.field(default_factory=ModelConfig) + loggers: LoggersConfig = sp.field(default_factory=LoggersConfig) + trainer: TrainerConfig = sp.field(default_factory=TrainerConfig) + checkpointing: Optional[CheckpointConfig] = sp.field( + default=None, help="optional configurations to deal with checkpointing" ) early_stopping: Optional[EarlyStoppingConfig] = None diff --git a/tests/test_create.py b/tests/test_create.py index d026e2d..5983784 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -11,6 +11,9 @@ class DT: foo: str = "bar" +DT(foo="bar") + + class TestCreation(unittest.TestCase): def test_from_dict(self): self.assertEqual( diff --git a/tests/test_flexyclass.py b/tests/test_flexyclass.py index 6c7857a..22731e6 100644 --- a/tests/test_flexyclass.py +++ b/tests/test_flexyclass.py @@ -1,36 +1,35 @@ import unittest -from dataclasses import dataclass -from springs import MISSING, from_dataclass, from_dict, make_target -from springs.flexyclasses import flexyclass +import springs as sp -@flexyclass -@dataclass +@sp.flexyclass class FlexyConfig: - a: int = MISSING + a: int = sp.MISSING -@dataclass +@sp.dataclass class FlexyConfigContainer: - f1: FlexyConfig = FlexyConfig(a=1) - f2: FlexyConfig = FlexyConfig(a=1, b=2) # type: ignore + f1: FlexyConfig = sp.field(default_factory=lambda: FlexyConfig(a=1)) + f2: FlexyConfig = sp.field( + default_factory=lambda: FlexyConfig(a=1, b=2) # type: ignore + ) -@flexyclass +@sp.flexyclass class PipelineStepConfig: - _target_: str = MISSING + _target_: str = sp.MISSING class TestFlexyClass(unittest.TestCase): def test_flexyclass(self): di = {"a": 1, "b": 2} - config = from_dict({"_target_": make_target(FlexyConfig), **di}) + config = sp.from_dict({"_target_": sp.make_target(FlexyConfig), **di}) self.assertEqual(config.a, di["a"]) self.assertEqual(config.b, di["b"]) def test_flexyclass_container(self): - config = from_dataclass(FlexyConfigContainer) + config = sp.from_dataclass(FlexyConfigContainer) self.assertTrue( hasattr(config.f2, "b"), "FlexyConfigContainer.f1.b is not set" ) diff --git a/tests/test_nested_init.py b/tests/test_nested_init.py index 6230261..9580b23 100644 --- a/tests/test_nested_init.py +++ b/tests/test_nested_init.py @@ -1,7 +1,7 @@ import unittest from dataclasses import dataclass -from springs.initialize import Target, init +import springs as sp class Inner: @@ -15,23 +15,23 @@ def __init__(self, a: Inner, b: int) -> None: self.b = b -@dataclass(unsafe_hash=True) +@dataclass class InnerConfig: - _target_: str = Target.to_string(Inner) + _target_: str = sp.Target.to_string(Inner) a: int = 1 -@dataclass(unsafe_hash=True) +@dataclass class OuterConfig: - _target_: str = Target.to_string(Outer) - a: InnerConfig = InnerConfig() + _target_: str = sp.Target.to_string(Outer) + a: InnerConfig = sp.field(default_factory=InnerConfig) b: int = 2 class TestInit(unittest.TestCase): def test_nested_init(self): config = OuterConfig() - out = init.now(config, Outer) + out = sp.init.now(config, Outer) self.assertTrue(isinstance(out, Outer)) self.assertTrue(isinstance(out.a, Inner)) diff --git a/tests/test_new_merge.py b/tests/test_new_merge.py index e97cea4..d52cebc 100644 --- a/tests/test_new_merge.py +++ b/tests/test_new_merge.py @@ -1,6 +1,6 @@ import pickle import unittest -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Dict, Optional from omegaconf import OmegaConf @@ -8,6 +8,7 @@ from springs import DictConfig from springs.core import merge from springs.flexyclasses import FlexyClass +from springs.field_utils import field @FlexyClass.flexyclass @@ -16,10 +17,9 @@ class ObjNestedConfig: @FlexyClass.flexyclass -@dataclass class ObjConfig: _target_: str = "springs.core" - nest: ObjNestedConfig = ObjNestedConfig() + nest: ObjNestedConfig = field(default_factory=ObjNestedConfig) @dataclass @@ -33,7 +33,7 @@ class AppCfg: elsewhere: Optional[Any] = None foo: FooCfg = field(default_factory=FooCfg) bar: bool = False - c: ObjConfig = ObjConfig() + c: ObjConfig = field(default_factory=ObjConfig) cn: Dict[str, ObjConfig] = field(default_factory=dict) diff --git a/tests/test_nicknames.py b/tests/test_nicknames.py index 2d5027b..91915f2 100644 --- a/tests/test_nicknames.py +++ b/tests/test_nicknames.py @@ -9,7 +9,7 @@ import springs as sp -@dataclass(unsafe_hash=True) +@dataclass class DataConfig: path: str = sp.MISSING @@ -17,7 +17,7 @@ class DataConfig: @sp.nickname("dev_config") @sp.dataclass class DevConfig: - data: DataConfig = DataConfig(path="/dev") + data: DataConfig = sp.field(default_factory=lambda: DataConfig(path="/dev")) name: str = "dev" batch_size: int = 32 diff --git a/tests/test_resolvers_in_nested.py b/tests/test_resolvers_in_nested.py index 504c5f8..40a85a1 100644 --- a/tests/test_resolvers_in_nested.py +++ b/tests/test_resolvers_in_nested.py @@ -7,14 +7,14 @@ import springs as sp -@sp.dataclass(unsafe_hash=True) +@sp.dataclass class NestedConfig: value: int = 0 @sp.dataclass class Config: - nested: NestedConfig = NestedConfig() + nested: NestedConfig = sp.field(default_factory=NestedConfig) li: List[NestedConfig] = sp.field(default_factory=list) di: Dict[str, NestedConfig] = sp.field(default_factory=dict) From bf92b4e29c97319224b3269a5d2838ea108da6c6 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Tue, 1 Aug 2023 23:32:52 -0700 Subject: [PATCH 5/6] forgot formatting --- tests/test_new_merge.py | 2 +- tests/test_nicknames.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_new_merge.py b/tests/test_new_merge.py index d52cebc..1af0f37 100644 --- a/tests/test_new_merge.py +++ b/tests/test_new_merge.py @@ -7,8 +7,8 @@ from springs import DictConfig from springs.core import merge -from springs.flexyclasses import FlexyClass from springs.field_utils import field +from springs.flexyclasses import FlexyClass @FlexyClass.flexyclass diff --git a/tests/test_nicknames.py b/tests/test_nicknames.py index 91915f2..4bf6365 100644 --- a/tests/test_nicknames.py +++ b/tests/test_nicknames.py @@ -17,7 +17,9 @@ class DataConfig: @sp.nickname("dev_config") @sp.dataclass class DevConfig: - data: DataConfig = sp.field(default_factory=lambda: DataConfig(path="/dev")) + data: DataConfig = sp.field( + default_factory=lambda: DataConfig(path="/dev") + ) name: str = "dev" batch_size: int = 32 From 5ed3f5466863090bb04803d7a6f91c1473624c07 Mon Sep 17 00:00:00 2001 From: Luca Soldaini Date: Tue, 1 Aug 2023 23:37:48 -0700 Subject: [PATCH 6/6] bumped version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6329fd2..b5b11af 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "springs" -version = "1.12.3" +version = "1.13.0" description = """\ A set of utilities to create and manage typed configuration files \ effectively, built on top of OmegaConf.\