diff --git a/README.md b/README.md index 56d38ef..87f7a28 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ What I need from a ML configuration library... import os from dataclasses import asdict, dataclass, field -from coqpit.coqpit import MISSING, Coqpit, check_argument +from coqpit import MISSING, Coqpit, check_argument @dataclass @@ -189,7 +189,7 @@ import os from dataclasses import asdict, dataclass, field from typing import List -from coqpit.coqpit import Coqpit, check_argument +from coqpit import Coqpit, check_argument import sys @@ -265,7 +265,7 @@ if __name__ == '__main__': ```python import os from dataclasses import dataclass -from coqpit.coqpit import Coqpit, check_argument +from coqpit import Coqpit, check_argument @dataclass diff --git a/coqpit/coqpit.py b/coqpit/coqpit.py index 0adc360..81edbd1 100644 --- a/coqpit/coqpit.py +++ b/coqpit/coqpit.py @@ -404,7 +404,15 @@ def _get_help(field): def _init_argparse( - parser, field_name, field_type, field_value, field_help, arg_prefix="", help_prefix="", relaxed_parser=False + parser, + field_name, + field_type, + field_value, + field_default_factory, + field_help, + arg_prefix="", + help_prefix="", + relaxed_parser=False, ): if field_value is None and not is_primitive_type(field_type) and not is_list(field_type): # aggregate types (fields with a Coqpit subclass as type) are not supported without None @@ -429,7 +437,7 @@ def _init_argparse( if is_list(list_field_type) and relaxed_parser: return parser - if field_value is None: + if field_value is None or field_default_factory is list: if not is_primitive_type(list_field_type) and not relaxed_parser: raise NotImplementedError(" [!] Empty list with non primitive inner type is currently not supported.") @@ -449,6 +457,7 @@ def _init_argparse( str(idx), list_field_type, fv, + field_default_factory, field_help="", help_prefix=f"{help_prefix} - ", arg_prefix=f"{arg_prefix}", @@ -739,9 +748,18 @@ def init_argparse( for field in class_fields: field_value = vars(self)[field.name] field_type = field.type + field_default_factory = field.default_factory field_help = _get_help(field) _init_argparse( - parser, field.name, field_type, field_value, field_help, arg_prefix, help_prefix, relaxed_parser + parser, + field.name, + field_type, + field_value, + field_default_factory, + field_help, + arg_prefix, + help_prefix, + relaxed_parser, ) return parser @@ -800,7 +818,7 @@ def check_argument( ), f" [!] prequested fields {prerequest} for {name} are not defined." # check if the path exists if is_path: - assert os.path.exists(c[name]), " [!] {c[name]} not exist." + assert os.path.exists(c[name]), f' [!] path for {name} ("{c[name]}") does not exist.' # skip the rest if the alternative field is defined. if alternative in c.keys() and c[alternative] is not None: return diff --git a/setup.py b/setup.py index 731b201..395e653 100644 --- a/setup.py +++ b/setup.py @@ -6,9 +6,10 @@ import setuptools.command.develop from setuptools import find_packages, setup -version = '0.0.6.6' +version = "0.0.7" cwd = os.path.dirname(os.path.abspath(__file__)) + class build_py(setuptools.command.build_py.build_py): # pylint: disable=too-many-ancestors def run(self): self.create_version_file() @@ -16,53 +17,54 @@ def run(self): @staticmethod def create_version_file(): - print('-- Building version ' + version) - version_path = os.path.join(cwd, 'version.py') - with open(version_path, 'w') as f: + print("-- Building version " + version) + version_path = os.path.join(cwd, "version.py") + with open(version_path, "w") as f: f.write("__version__ = '{}'\n".format(version)) + class develop(setuptools.command.develop.develop): def run(self): build_py.create_version_file() setuptools.command.develop.develop.run(self) -requirements = open(os.path.join(cwd, 'requirements.txt'), 'r').readlines() -with open('README.md', "r", encoding="utf-8") as readme_file: +requirements = open(os.path.join(cwd, "requirements.txt"), "r").readlines() +with open("README.md", "r", encoding="utf-8") as readme_file: README = readme_file.read() setup( - name='coqpit', + name="coqpit", version=version, - url='https://github.com/erogol/coqpit', - author='Eren Gölge', - author_email='egolge@coqui.ai', - description='Simple (maybe too simple), light-weight config management through python data-classes.', + url="https://github.com/erogol/coqpit", + author="Eren Gölge", + author_email="egolge@coqui.ai", + description="Simple (maybe too simple), light-weight config management through python data-classes.", long_description=README, long_description_content_type="text/markdown", - license='', + license="", include_package_data=True, - packages=find_packages(include=['coqpit*']), + packages=find_packages(include=["coqpit*"]), project_urls={ - 'Tracker': 'https://github.com/coqui-ai/coqpit/issues', - 'Repository': 'https://github.com/coqui-ai/coqpit', - 'Discussions': 'https://github.com/coqui-ai/coqpit/discussions', + "Tracker": "https://github.com/coqui-ai/coqpit/issues", + "Repository": "https://github.com/coqui-ai/coqpit", + "Discussions": "https://github.com/coqui-ai/coqpit/discussions", }, cmdclass={ - 'build_py': build_py, - 'develop': develop, + "build_py": build_py, + "develop": develop, }, install_requires=requirements, - python_requires='>=3.6.0', + python_requires=">=3.6.0", classifiers=[ "Programming Language :: Python", "Programming Language :: Python :: 3", - 'Development Status :: 4 - Beta', + "Development Status :: 4 - Beta", "Intended Audience :: Developers", "Operating System :: POSIX :: Linux", "Operating System :: MacOS", - "Operating System :: Microsoft :: Windows" + "Operating System :: Microsoft :: Windows", ], - zip_safe=False + zip_safe=False, ) diff --git a/tests/test_parse_argparse.py b/tests/test_parse_argparse.py index 16b2900..a6faa78 100644 --- a/tests/test_parse_argparse.py +++ b/tests/test_parse_argparse.py @@ -21,6 +21,9 @@ class SimpleConfig(Coqpit): ) empty_int_list: List[int] = field(default=None, metadata={"help": "int list without default value"}) empty_str_list: List[str] = field(default=None, metadata={"help": "str list without default value"}) + list_with_default_factory: List[str] = field( + default_factory=list, metadata={"help": "str list with default factory"} + ) # mylist_without_default: List[SimplerConfig] = field(default=None, metadata={'help': 'list of SimplerConfig'}) # NOT SUPPORTED YET! @@ -44,6 +47,7 @@ def test_parse_argparse(): args.extend(["--coqpit.mylist_with_default.1.val_a", "111"]) args.extend(["--coqpit.empty_int_list", "111", "222", "333"]) args.extend(["--coqpit.empty_str_list", "[foo=bar]", "[baz=qux]", "[blah,p=0.5,r=1~3]"]) + args.extend(["--coqpit.list_with_default_factory", "blah"]) # initial config config = SimpleConfig() @@ -58,6 +62,7 @@ def test_parse_argparse(): mylist_with_default=[SimplerConfig(val_a=222), SimplerConfig(val_a=111)], empty_int_list=[111, 222, 333], empty_str_list=["[foo=bar]", "[baz=qux]", "[blah,p=0.5,r=1~3]"], + list_with_default_factory=["blah"], ) # create and init argparser with Coqpit