From 42696e41780f05d9022d9f9fdf08be249005a33a Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Mon, 5 Jun 2023 19:09:04 -0700 Subject: [PATCH 01/45] convert library to compatibility w pydantic v2, start fixing tests --- confection/__init__.py | 33 ++++++++++++++++----------------- confection/tests/test_config.py | 9 ++++++--- requirements.txt | 2 +- setup.cfg | 2 +- 4 files changed, 24 insertions(+), 22 deletions(-) diff --git a/confection/__init__.py b/confection/__init__.py index 758c4f8..31e8a93 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -8,8 +8,7 @@ from configparser import ParsingError from pathlib import Path from pydantic import BaseModel, create_model, ValidationError, Extra -from pydantic.main import ModelMetaclass -from pydantic.fields import ModelField +from pydantic.fields import FieldInfo import srsly import catalogue import inspect @@ -668,13 +667,13 @@ def alias_generator(name: str) -> str: return name -def copy_model_field(field: ModelField, type_: Any) -> ModelField: +def copy_model_field(field: FieldInfo, type_: Any) -> FieldInfo: """Copy a model field and assign a new type, e.g. to accept an Any type even though the original value is typed differently. """ - return ModelField( + return FieldInfo( name=field.name, - type_=type_, + annotation=type_, class_validators=field.class_validators, model_config=field.model_config, default=field.default, @@ -829,12 +828,12 @@ def _fill( value = overrides[key_parent] config[key] = value if cls.is_promise(value): - if key in schema.__fields__ and not resolve: + if key in schema.model_fields and not resolve: # If we're not resolving the config, make sure that the field # expecting the promise is typed Any so it doesn't fail # validation if it doesn't receive the function return value - field = schema.__fields__[key] - schema.__fields__[key] = copy_model_field(field, Any) + field = schema.model_fields[key] + schema.model_fields[key] = copy_model_field(field, Any) promise_schema = cls.make_promise_schema(value, resolve=resolve) filled[key], validation[v_key], final[key] = cls._fill( value, @@ -869,10 +868,10 @@ def _fill( validation[v_key] = [] elif hasattr(value, "items"): field_type = EmptySchema - if key in schema.__fields__: - field = schema.__fields__[key] - field_type = field.type_ - if not isinstance(field.type_, ModelMetaclass): + if key in schema.model_fields: + field = schema.model_fields[key] + field_type = field.annotation + if field_type is None or not issubclass(field_type, BaseModel): # If we don't have a pydantic schema and just a type field_type = EmptySchema filled[key], validation[v_key], final[key] = cls._fill( @@ -900,21 +899,21 @@ def _fill( exclude = [] if validate: try: - result = schema.parse_obj(validation) + result = schema.model_validate(validation) except ValidationError as e: raise ConfigValidationError( config=config, errors=e.errors(), parent=parent ) from None else: # Same as parse_obj, but without validation - result = schema.construct(**validation) + result = schema.model_construct(**validation) # If our schema doesn't allow extra values, we need to filter them # manually because .construct doesn't parse anything if schema.Config.extra in (Extra.forbid, Extra.ignore): - fields = schema.__fields__.keys() - exclude = [k for k in result.__fields_set__ if k not in fields] + fields = schema.model_fields.keys() + exclude = [k for k in result.model_fields_set if k not in fields] exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()]) - validation.update(result.dict(exclude=exclude_validation)) + validation.update(result.model_dump(exclude=exclude_validation)) filled, final = cls._update_from_parsed(validation, filled, final) if exclude: filled = {k: v for k, v in filled.items() if k not in exclude} diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index 6b2c730..2b726b3 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -104,7 +104,7 @@ def test_invalidate_simple_config(): my_registry._fill(invalid_config, HelloIntsSchema) error = exc_info.value assert len(error.errors) == 1 - assert "type_error.integer" in error.error_types + assert "int_parsing" in error.error_types def test_invalidate_extra_args(): @@ -154,8 +154,8 @@ def test_parse_args(): def test_make_promise_schema(): schema = my_registry.make_promise_schema(good_catsie) - assert "evil" in schema.__fields__ - assert "cute" in schema.__fields__ + assert "evil" in schema.model_fields + assert "cute" in schema.model_fields def test_validate_promise(): @@ -234,12 +234,15 @@ class TestSchema(BaseModel): my_registry.resolve({"cfg": config}, schema=TestSchema) +@pytest.mark.skip("In Pydantic v2, int/float cannot be coerced to str so this test will fail.") def test_resolve_schema_coerced(): class TestBaseSchema(BaseModel): test1: str test2: bool test3: float + model_config = {"strict": False} + class TestSchema(BaseModel): cfg: TestBaseSchema diff --git a/requirements.txt b/requirements.txt index 3715d80..c2e582e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pydantic>=1.7.4,!=1.8,!=1.8.1,<1.11.0 +pydantic==2.0b2 typing_extensions>=3.7.4.1,<4.5.0; python_version < "3.8" srsly>=2.4.0,<3.0.0 # Development requirements diff --git a/setup.cfg b/setup.cfg index b9da89a..c7caf63 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,7 @@ zip_safe = true include_package_data = true python_requires = >=3.6 install_requires = - pydantic>=1.7.4,!=1.8,!=1.8.1,<1.11.0 + pydantic==2.0b2 typing_extensions>=3.7.4.1,<4.5.0; python_version < "3.8" srsly>=2.4.0,<3.0.0 From ccedeb05fe4d0269e37ce62bc731c2ba5de48ee9 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Mon, 5 Jun 2023 19:17:31 -0700 Subject: [PATCH 02/45] fix constr and model config access --- confection/__init__.py | 2 +- confection/tests/test_config.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/confection/__init__.py b/confection/__init__.py index 31e8a93..12615c4 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -909,7 +909,7 @@ def _fill( result = schema.model_construct(**validation) # If our schema doesn't allow extra values, we need to filter them # manually because .construct doesn't parse anything - if schema.Config.extra in (Extra.forbid, Extra.ignore): + if schema.model_config.get("extra", Extra.forbid) in (Extra.forbid, Extra.ignore): fields = schema.model_fields.keys() exclude = [k for k in result.model_fields_set if k not in fields] exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()]) diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index 2b726b3..c1bece4 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -8,6 +8,7 @@ import pickle from pydantic import BaseModel, StrictFloat, PositiveInt, constr +from pydantic.fields import Field from pydantic.types import StrictBool from confection import ConfigValidationError, Config @@ -331,7 +332,7 @@ def test_validation_custom_types(): def complex_args( rate: StrictFloat, steps: PositiveInt = 10, # type: ignore - log_level: constr(regex="(DEBUG|INFO|WARNING|ERROR)") = "ERROR", + log_level: str = Field("ERROR", pattern="(DEBUG|INFO|WARNING|ERROR)"), ): return None From 05ce65b2556f1ea153b016ace710794afaad30bb Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Thu, 8 Jun 2023 09:38:18 -0400 Subject: [PATCH 03/45] add support for Pydantic models and dataclasses out of registered functions --- confection/__init__.py | 41 ++++++++++++++++++++++----------- confection/tests/test_config.py | 23 ++++++++---------- 2 files changed, 37 insertions(+), 27 deletions(-) diff --git a/confection/__init__.py b/confection/__init__.py index 12615c4..3358414 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -1,7 +1,7 @@ from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type, Mapping -from typing import Iterable, Sequence, cast +from typing import Iterable, Sequence, Set, cast from types import GeneratorType -from dataclasses import dataclass +from dataclasses import dataclass, is_dataclass from configparser import ConfigParser, ExtendedInterpolation, MAX_INTERPOLATION_DEPTH from configparser import InterpolationMissingOptionError, InterpolationSyntaxError from configparser import NoSectionError, NoOptionError, InterpolationDepthError @@ -671,15 +671,9 @@ def copy_model_field(field: FieldInfo, type_: Any) -> FieldInfo: """Copy a model field and assign a new type, e.g. to accept an Any type even though the original value is typed differently. """ - return FieldInfo( - name=field.name, - annotation=type_, - class_validators=field.class_validators, - model_config=field.model_config, - default=field.default, - default_factory=field.default_factory, - required=field.required, - ) + field_info = copy.deepcopy(field) + field_info.annotation = type_ + return field_info class EmptySchema(BaseModel): @@ -807,6 +801,7 @@ def _fill( resolve: bool = True, parent: str = "", overrides: Dict[str, Dict[str, Any]] = {}, + resolved_object_keys: Set[str] = set() ) -> Tuple[ Union[Dict[str, Any], Config], Union[Dict[str, Any], Config], Dict[str, Any] ]: @@ -842,6 +837,7 @@ def _fill( resolve=resolve, parent=key_parent, overrides=overrides, + resolved_object_keys=resolved_object_keys ) reg_name, func_name = cls.get_constructor(final[key]) args, kwargs = cls.parse_args(final[key]) @@ -853,6 +849,9 @@ def _fill( # We don't want to try/except this and raise our own error # here, because we want the traceback if the function fails. getter_result = getter(*args, **kwargs) + + if isinstance(getter_result, BaseModel) or is_dataclass(getter_result): + resolved_object_keys.add(key) else: # We're not resolving and calling the function, so replace # the getter_result with a Promise class @@ -909,11 +908,27 @@ def _fill( result = schema.model_construct(**validation) # If our schema doesn't allow extra values, we need to filter them # manually because .construct doesn't parse anything - if schema.model_config.get("extra", Extra.forbid) in (Extra.forbid, Extra.ignore): + if schema.model_config.get("extra", "forbid") in ("forbid", "ignore"): fields = schema.model_fields.keys() exclude = [k for k in result.model_fields_set if k not in fields] exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()]) - validation.update(result.model_dump(exclude=exclude_validation)) + # Do a shallow serialization first + # If any of the sub-objects are Pydantic models, first check if they + # were resolved earlier from a registry. If they weren't resolved + # they are part of a nested schema and need to be serialized with + # model.dict() + # Allows for returning Pydantic models from a registered function + shallow_result_dict = dict(result) + if result.model_extra is not None: + shallow_result_dict.update(result.model_extra) + result_dict = {} + for k, v in shallow_result_dict.items(): + if k in exclude_validation: + continue + result_dict[k] = v + if isinstance(v, BaseModel) and k not in resolved_object_keys: + result_dict[k] = v.model_dump() + validation.update(result_dict) filled, final = cls._update_from_parsed(validation, filled, final) if exclude: filled = {k: v for k, v in filled.items() if k not in exclude} diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index c1bece4..2e7f458 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -66,16 +66,14 @@ class HelloIntsSchema(BaseModel): hello: int world: int - class Config: - extra = "forbid" + model_config = {"extra": "forbid"} class DefaultsSchema(BaseModel): required: int optional: str = "default value" - class Config: - extra = "forbid" + model_config = {"extra": "forbid"} class ComplexSchema(BaseModel): @@ -217,8 +215,7 @@ class TestBaseSchema(BaseModel): one: PositiveInt two: TestBaseSubSchema - class Config: - extra = "forbid" + model_config = {"extra": "forbid"} class TestSchema(BaseModel): cfg: TestBaseSchema @@ -1208,8 +1205,7 @@ class TestSchemaContent(BaseModel): a: str b: int - class Config: - extra = "forbid" + model_config = {"extra": "forbid"} class TestSchema(BaseModel): cfg: TestSchemaContent @@ -1230,8 +1226,7 @@ class TestSchemaContent2(BaseModel): a: str b: int - class Config: - extra = "allow" + model_config = {"extra": "allow"} class TestSchema2(BaseModel): cfg: TestSchemaContent2 @@ -1255,9 +1250,9 @@ class Schema(BaseModel): assert e1.show_config is True assert len(e1.errors) == 1 assert e1.errors[0]["loc"] == ("world",) - assert e1.errors[0]["msg"] == "value is not a valid integer" - assert e1.errors[0]["type"] == "type_error.integer" - assert e1.error_types == set(["type_error.integer"]) + assert e1.errors[0]["msg"] == "Input should be a valid integer, unable to parse string as an integer" + assert e1.errors[0]["type"] == "int_parsing" + assert e1.error_types == {"int_parsing"} # Create a new error with overrides title = "Custom error" desc = "Some error description here" @@ -1287,7 +1282,7 @@ class BaseSchema(BaseModel): assert filled["catsie"]["cute"] is True with pytest.raises(ConfigValidationError): my_registry.resolve(config, schema=BaseSchema) - filled2 = my_registry.fill(config, schema=BaseSchema) + filled2 = my_registry.fill(config, schema=BaseSchema, validate=False) assert filled2["catsie"]["cute"] is True resolved = my_registry.resolve(filled2) assert resolved["catsie"] == "meow" From a822d3c9fa15869bbd6bd553a2bb2ed4e12e3b08 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 30 Jun 2023 12:00:52 -0700 Subject: [PATCH 04/45] update reqs --- requirements.txt | 4 ++-- setup.cfg | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index c2e582e..47c6863 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -pydantic==2.0b2 -typing_extensions>=3.7.4.1,<4.5.0; python_version < "3.8" +pydantic>=2.0,<2.1 +typing_extensions>=4.6.1; python_version < "3.8" srsly>=2.4.0,<3.0.0 # Development requirements pathy>=0.3.5 diff --git a/setup.cfg b/setup.cfg index c7caf63..900541d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,7 +31,7 @@ include_package_data = true python_requires = >=3.6 install_requires = pydantic==2.0b2 - typing_extensions>=3.7.4.1,<4.5.0; python_version < "3.8" + typing_extensions>=4.6.1,<5.0.0; python_version < "3.8" srsly>=2.4.0,<3.0.0 [sdist] From 59f3f55a3111f22076348169f95aee0c1a25a7c9 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 30 Jun 2023 12:01:18 -0700 Subject: [PATCH 05/45] start compat --- confection/__init__.py | 77 +++++++++++++++++++++++---------- confection/tests/test_2.py | 36 +++++++++++++++ confection/tests/test_config.py | 6 +-- 3 files changed, 94 insertions(+), 25 deletions(-) create mode 100644 confection/tests/test_2.py diff --git a/confection/__init__.py b/confection/__init__.py index 3358414..a590085 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -1,6 +1,7 @@ from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type, Mapping -from typing import Iterable, Sequence, Set, cast +from typing import Iterable, Sequence, Set, TypeVar, cast from types import GeneratorType +from inspect import isclass from dataclasses import dataclass, is_dataclass from configparser import ConfigParser, ExtendedInterpolation, MAX_INTERPOLATION_DEPTH from configparser import InterpolationMissingOptionError, InterpolationSyntaxError @@ -9,6 +10,7 @@ from pathlib import Path from pydantic import BaseModel, create_model, ValidationError, Extra from pydantic.fields import FieldInfo +from pydantic.version import VERSION as PYDANTIC_VERSION import srsly import catalogue import inspect @@ -34,6 +36,8 @@ # Regex to detect whether a value contains a variable VARIABLE_RE = re.compile(r"\$\{[\w\.:]+\}") +PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") + class CustomInterpolation(ExtendedInterpolation): def before_read(self, parser, section, option, value): @@ -667,15 +671,43 @@ def alias_generator(name: str) -> str: return name -def copy_model_field(field: FieldInfo, type_: Any) -> FieldInfo: +def copy_model_field(field: Union["FieldInfo", "ModelField"], type_: Type) -> Union["FieldInfo", "ModelField"]: """Copy a model field and assign a new type, e.g. to accept an Any type even though the original value is typed differently. """ field_info = copy.deepcopy(field) - field_info.annotation = type_ + if PYDANTIC_V2: + field_info.annotation = type_ + else: + field_info.type_ = type_ return field_info +def get_model_config_extra(model: Type[BaseModel]) -> str: + if PYDANTIC_V2: + extra = model.model_config.get("extra", "forbid") + else: + extra = str(model.Config.extra) or "forbid" + assert isinstance(extra, str) + return extra + + + +_ModelT = TypeVar("_ModelT", bound=BaseModel) + + +def model_validate(Schema: Type[_ModelT], data: Dict[str, Any]) -> _ModelT: + return Schema.model_validate(**data) if PYDANTIC_V2 else Schema(**data) + + +def model_construct(Schema: Type[_ModelT], data: Dict[str, Any]) -> _ModelT: + return Schema.model_construct(**data) if PYDANTIC_V2 else Schema.construct(**data) + + +def model_dump(instance: BaseModel) -> Dict[str, Any]: + return instance.model_dump() if PYDANTIC_V2 else instance.dict() + + class EmptySchema(BaseModel): class Config: extra = "allow" @@ -860,17 +892,18 @@ def _fill( ) validation[v_key] = getter_result final[key] = getter_result - if isinstance(validation[v_key], GeneratorType): - # If value is a generator we can't validate type without - # consuming it (which doesn't work if it's infinite – see - # schedule for examples). So we skip it. - validation[v_key] = [] + # if isinstance(validation[v_key], GeneratorType): + # # If value is a generator we can't validate type without + # # consuming it (which doesn't work if it's infinite – see + # # schedule for examples). So we skip it. + # validation[v_key] = [] elif hasattr(value, "items"): field_type = EmptySchema - if key in schema.model_fields: - field = schema.model_fields[key] - field_type = field.annotation - if field_type is None or not issubclass(field_type, BaseModel): + fields = schema.model_fields if PYDANTIC_V2 else schema.__fields__ + if key in fields: + field = fields[key] + field_type = field.annotation if PYDANTIC_V2 else field.type_ + if field_type is None or not (isclass(field_type) and issubclass(field_type, BaseModel)): # If we don't have a pydantic schema and just a type field_type = EmptySchema filled[key], validation[v_key], final[key] = cls._fill( @@ -888,29 +921,29 @@ def _fill( final[key] = list(final[key].values()) else: filled[key] = value - # Prevent pydantic from consuming generator if part of a union - validation[v_key] = ( - value if not isinstance(value, GeneratorType) else [] - ) + validation[v_key] = value final[key] = value # Now that we've filled in all of the promises, update with defaults # from schema, and validate if validation is enabled exclude = [] if validate: try: - result = schema.model_validate(validation) + result = schema.model_validate(validation) if PYDANTIC_V2 else schema(**validation) except ValidationError as e: + raise ConfigValidationError( config=config, errors=e.errors(), parent=parent ) from None else: # Same as parse_obj, but without validation - result = schema.model_construct(**validation) + result = schema.model_construct(**validation) if PYDANTIC_V2 else schema.construct(**validation) # If our schema doesn't allow extra values, we need to filter them # manually because .construct doesn't parse anything - if schema.model_config.get("extra", "forbid") in ("forbid", "ignore"): - fields = schema.model_fields.keys() - exclude = [k for k in result.model_fields_set if k not in fields] + extra_attr = get_model_config_extra(schema) + if extra_attr in ("forbid", "ignore"): + fields = schema.model_fields.keys() if PYDANTIC_V2 else schema.__fields__.keys() + result_fields_set = result.model_fields_set if PYDANTIC_V2 else result.__fields_set__ + exclude = [k for k in result_fields_set if k not in fields] exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()]) # Do a shallow serialization first # If any of the sub-objects are Pydantic models, first check if they @@ -927,7 +960,7 @@ def _fill( continue result_dict[k] = v if isinstance(v, BaseModel) and k not in resolved_object_keys: - result_dict[k] = v.model_dump() + result_dict[k] = model_dump(v) validation.update(result_dict) filled, final = cls._update_from_parsed(validation, filled, final) if exclude: diff --git a/confection/tests/test_2.py b/confection/tests/test_2.py new file mode 100644 index 0000000..290ecd7 --- /dev/null +++ b/confection/tests/test_2.py @@ -0,0 +1,36 @@ +import dataclasses +from typing import Union, Iterable +import catalogue +from confection import registry, Config +from pydantic import BaseModel + +# Create a new registry. +registry.optimizers = catalogue.create("confection", "optimizers", entry_points=False) + + +# Define a dummy optimizer class. + +@dataclasses.dataclass +class MyCoolOptimizer: + learn_rate: float + gamma: float + + +@registry.optimizers.register("my_cool_optimizer.v1") +def make_my_optimizer(learn_rate: Union[float, Iterable[float]], gamma: float): + return MyCoolOptimizer(learn_rate=learn_rate, gamma=gamma) + + +if __name__ == "__main__": + # Load the config file from disk, resolve it and fetch the instantiated optimizer object. + cfg_str = """ +[optimizer] +@optimizers = "my_cool_optimizer.v1" +learn_rate = 0.001 +gamma = 1e-8 + """ + config = Config().from_str(cfg_str) + resolved = registry.resolve(config) + optimizer = resolved["optimizer"] # MyCoolOptimizer(learn_rate=0.001, gamma=1e-08) + + print(config, resolved, optimizer) diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index 2e7f458..587b759 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -584,7 +584,7 @@ def test_schedule(): assert isinstance(result, GeneratorType) @my_registry.optimizers("test_optimizer.v2") - def test_optimizer2(rate: Generator) -> Generator: + def test_optimizer2(rate: Iterable[float]) -> Iterable[float]: return rate cfg = { @@ -595,7 +595,7 @@ def test_optimizer2(rate: Generator) -> Generator: assert isinstance(result, GeneratorType) @my_registry.optimizers("test_optimizer.v3") - def test_optimizer3(schedules: Dict[str, Generator]) -> Generator: + def test_optimizer3(schedules: Dict[str, Iterable[float]]) -> Iterable[float]: return schedules["rate"] cfg = { @@ -606,7 +606,7 @@ def test_optimizer3(schedules: Dict[str, Generator]) -> Generator: assert isinstance(result, GeneratorType) @my_registry.optimizers("test_optimizer.v4") - def test_optimizer4(*schedules: Generator) -> Generator: + def test_optimizer4(*schedules: Iterable[float]) -> Iterable[float]: return schedules[0] From 3b51749753ffdfe3e8ce0ab0eaea8b1d86e31ae8 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 30 Jun 2023 12:48:59 -0700 Subject: [PATCH 06/45] small corrrections around new model_construct behavior --- confection/__init__.py | 27 +++++++++++++++------------ confection/tests/test_config.py | 8 +++++--- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/confection/__init__.py b/confection/__init__.py index a8c1a32..e21900a 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -675,15 +675,17 @@ def copy_model_field(field: FieldInfo, type_: Any) -> FieldInfo: class EmptySchema(BaseModel): - class Config: - extra = "allow" - arbitrary_types_allowed = True + model_config = { + "extra": "allow", + "arbitrary_types_allowed": True + } -class _PromiseSchemaConfig: - extra = "forbid" - arbitrary_types_allowed = True - alias_generator = alias_generator +_promise_schema_config = { + "extra": "forbid", + "arbitrary_types_allowed": True, + "alias_generator": alias_generator +} @dataclass @@ -902,13 +904,14 @@ def _fill( config=config, errors=e.errors(), parent=parent ) from None else: - # Same as parse_obj, but without validation - result = schema.model_construct(**validation) + # Same as model_validate, but without validation + fields_set = set(schema.model_fields.keys()) + result = schema.model_construct(fields_set, **validation) # If our schema doesn't allow extra values, we need to filter them # manually because .construct doesn't parse anything if schema.model_config.get("extra", "forbid") in ("forbid", "ignore"): - fields = schema.model_fields.keys() - exclude = [k for k in result.model_fields_set if k not in fields] + fields = result.model_fields_set + exclude = [k for k in dict(result).keys() if k not in fields] exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()]) # Do a shallow serialization first # If any of the sub-objects are Pydantic models, first check if they @@ -1055,7 +1058,7 @@ def make_promise_schema( else: name = RESERVED_FIELDS.get(param.name, param.name) sig_args[name] = (annotation, default) - sig_args["__config__"] = _PromiseSchemaConfig + sig_args["__config__"] = _promise_schema_config return create_model("ArgModel", **sig_args) diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index 57b56fd..876a50c 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -7,7 +7,7 @@ from types import GeneratorType import pickle -from pydantic import BaseModel, StrictFloat, PositiveInt, constr +from pydantic import BaseModel, StrictFloat, PositiveInt from pydantic.fields import Field from pydantic.types import StrictBool @@ -1205,7 +1205,9 @@ class TestSchemaContent(BaseModel): a: str b: int - model_config = {"extra": "forbid"} + model_config = { + "extra": "forbid", + } class TestSchema(BaseModel): cfg: TestSchemaContent @@ -1282,7 +1284,7 @@ class BaseSchema(BaseModel): assert filled["catsie"]["cute"] is True with pytest.raises(ConfigValidationError): my_registry.resolve(config, schema=BaseSchema) - filled2 = my_registry.fill(config, schema=BaseSchema, validate=False) + filled2 = my_registry.fill(config, schema=BaseSchema) assert filled2["catsie"]["cute"] is True resolved = my_registry.resolve(filled2) assert resolved["catsie"] == "meow" From 2df560fd5fa05685cc64d246e4765a79b60826f2 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 30 Jun 2023 12:58:14 -0700 Subject: [PATCH 07/45] use Iterator instead of Generator and GeneratorType --- confection/__init__.py | 20 ++++++++------------ confection/tests/test_config.py | 15 +++++++-------- confection/tests/util.py | 2 +- 3 files changed, 16 insertions(+), 21 deletions(-) diff --git a/confection/__init__.py b/confection/__init__.py index e21900a..5975b14 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -1,6 +1,5 @@ from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type, Mapping -from typing import Iterable, Sequence, Set, cast -from types import GeneratorType +from typing import Iterable, Iterator, Sequence, Set, cast from dataclasses import dataclass, is_dataclass from configparser import ConfigParser, ExtendedInterpolation, MAX_INTERPOLATION_DEPTH from configparser import InterpolationMissingOptionError, InterpolationSyntaxError @@ -674,6 +673,10 @@ def copy_model_field(field: FieldInfo, type_: Any) -> FieldInfo: return field_info +def _safe_is_subclass(cls: type, expected: type) -> bool: + return inspect.isclass(cls) and issubclass(cls, BaseModel) + + class EmptySchema(BaseModel): model_config = { "extra": "allow", @@ -860,17 +863,12 @@ def _fill( ) validation[v_key] = getter_result final[key] = getter_result - if isinstance(validation[v_key], GeneratorType): - # If value is a generator we can't validate type without - # consuming it (which doesn't work if it's infinite – see - # schedule for examples). So we skip it. - validation[v_key] = [] elif hasattr(value, "items"): field_type = EmptySchema if key in schema.model_fields: field = schema.model_fields[key] field_type = field.annotation - if field_type is None or not issubclass(field_type, BaseModel): + if field_type is None or not _safe_is_subclass(field_type, BaseModel): # If we don't have a pydantic schema and just a type field_type = EmptySchema filled[key], validation[v_key], final[key] = cls._fill( @@ -889,9 +887,7 @@ def _fill( else: filled[key] = value # Prevent pydantic from consuming generator if part of a union - validation[v_key] = ( - value if not isinstance(value, GeneratorType) else [] - ) + validation[v_key] = value final[key] = value # Now that we've filled in all of the promises, update with defaults # from schema, and validate if validation is enabled @@ -965,7 +961,7 @@ def _update_from_parsed( final[key] = value elif ( value != final[key] or not isinstance(type(value), type(final[key])) - ) and not isinstance(final[key], GeneratorType): + ): final[key] = value return filled, final diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index 876a50c..413e804 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -3,8 +3,7 @@ import catalogue import pytest -from typing import Dict, Optional, Iterable, Callable, Any, Union, List, Tuple -from types import GeneratorType +from typing import Dict, Optional, Iterator, Iterable, Callable, Any, Union, List, Tuple import pickle from pydantic import BaseModel, StrictFloat, PositiveInt @@ -581,10 +580,10 @@ def test_schedule(): cfg = {"@schedules": "test_schedule.v2"} result = my_registry.resolve({"test": cfg})["test"] - assert isinstance(result, GeneratorType) + assert isinstance(result, Iterator) @my_registry.optimizers("test_optimizer.v2") - def test_optimizer2(rate: Generator) -> Generator: + def test_optimizer2(rate: Iterator) -> Iterator: return rate cfg = { @@ -592,10 +591,10 @@ def test_optimizer2(rate: Generator) -> Generator: "rate": {"@schedules": "test_schedule.v2"}, } result = my_registry.resolve({"test": cfg})["test"] - assert isinstance(result, GeneratorType) + assert isinstance(result, Iterator) @my_registry.optimizers("test_optimizer.v3") - def test_optimizer3(schedules: Dict[str, Generator]) -> Generator: + def test_optimizer3(schedules: Dict[str, Iterator]) -> Iterator: return schedules["rate"] cfg = { @@ -603,10 +602,10 @@ def test_optimizer3(schedules: Dict[str, Generator]) -> Generator: "schedules": {"rate": {"@schedules": "test_schedule.v2"}}, } result = my_registry.resolve({"test": cfg})["test"] - assert isinstance(result, GeneratorType) + assert isinstance(result, Iterator) @my_registry.optimizers("test_optimizer.v4") - def test_optimizer4(*schedules: Generator) -> Generator: + def test_optimizer4(*schedules: Iterator) -> Iterator: return schedules[0] diff --git a/confection/tests/util.py b/confection/tests/util.py index 1f4c56f..112e998 100644 --- a/confection/tests/util.py +++ b/confection/tests/util.py @@ -20,7 +20,7 @@ import catalogue import confection -FloatOrSeq = Union[float, List[float], Generator] +FloatOrSeq = Union[float, Iterable[float]] InT = TypeVar("InT") OutT = TypeVar("OutT") From 5213981266cc4a77084b1e810aeaca23125a43b6 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 30 Jun 2023 12:59:15 -0700 Subject: [PATCH 08/45] don't validate in fill_without_resolve test --- confection/tests/test_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index 413e804..9914055 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -1283,7 +1283,7 @@ class BaseSchema(BaseModel): assert filled["catsie"]["cute"] is True with pytest.raises(ConfigValidationError): my_registry.resolve(config, schema=BaseSchema) - filled2 = my_registry.fill(config, schema=BaseSchema) + filled2 = my_registry.fill(config, schema=BaseSchema, validate=False) assert filled2["catsie"]["cute"] is True resolved = my_registry.resolve(filled2) assert resolved["catsie"] == "meow" From ff3b55f1a8a859060143c5c0962a28e62ac63afd Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 30 Jun 2023 13:00:43 -0700 Subject: [PATCH 09/45] bump reqs --- requirements.txt | 4 ++-- setup.cfg | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/requirements.txt b/requirements.txt index c2e582e..d54da9b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -pydantic==2.0b2 -typing_extensions>=3.7.4.1,<4.5.0; python_version < "3.8" +pydantic>=2.0,<2.1 +typing_extensions>=4.6.1,<5.0; python_version < "3.8" srsly>=2.4.0,<3.0.0 # Development requirements pathy>=0.3.5 diff --git a/setup.cfg b/setup.cfg index 60a3f92..0849d1b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,8 +30,8 @@ zip_safe = true include_package_data = true python_requires = >=3.6 install_requires = - pydantic==2.0b2 - typing_extensions>=3.7.4.1,<4.5.0; python_version < "3.8" + pydantic>=2.0,<2.1 + typing_extensions>=4.6.1,<5.0; python_version < "3.8" srsly>=2.4.0,<3.0.0 [sdist] From 19994776871a10a2176ee4753f168c214ae637a7 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 30 Jun 2023 13:09:06 -0700 Subject: [PATCH 10/45] refactor and fix for mypy --- confection/__init__.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/confection/__init__.py b/confection/__init__.py index 5975b14..8775915 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -867,10 +867,8 @@ def _fill( field_type = EmptySchema if key in schema.model_fields: field = schema.model_fields[key] - field_type = field.annotation - if field_type is None or not _safe_is_subclass(field_type, BaseModel): - # If we don't have a pydantic schema and just a type - field_type = EmptySchema + if field.annotation is not None and _safe_is_subclass(field.annotation, BaseModel): + field_type = field.annotation filled[key], validation[v_key], final[key] = cls._fill( value, field_type, From 45f12bab454e760217a696df4342638baf9cff59 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 30 Jun 2023 13:11:09 -0700 Subject: [PATCH 11/45] disable python 3.6 --- .github/workflows/tests.yml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0c167e2..1cd31d6 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -14,11 +14,6 @@ jobs: matrix: os: [ubuntu-latest, windows-latest, macos-latest] python_version: ["3.7", "3.8", "3.9", "3.10", "3.11"] - include: - - os: windows-2019 - python_version: "3.6" - - os: ubuntu-20.04 - python_version: "3.6" runs-on: ${{ matrix.os }} From efd17374ced710d378d9046bd0b7bad0905a69b5 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 30 Jun 2023 13:24:41 -0700 Subject: [PATCH 12/45] rm extra python 3.6 ref --- .github/workflows/tests.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1cd31d6..0201b21 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -35,7 +35,6 @@ jobs: - name: Run mypy run: python -m mypy confection - if: matrix.python_version != '3.6' - name: Delete source directory run: rm -rf confection From ca99729924effddcd0b05c7a5471447a56b1a09d Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 30 Jun 2023 13:30:36 -0700 Subject: [PATCH 13/45] check that pydantic and dataclass versions of Optimizer both work --- confection/tests/test_config.py | 55 ++++++++++++++++++++++----------- confection/tests/util.py | 26 +++++++++++++++- 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index 9914055..61b955f 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -46,7 +46,7 @@ """ -OPTIMIZER_CFG = """ +OPTIMIZER_DATACLASS_CFG = """ [optimizer] @optimizers = "Adam.v1" beta1 = 0.9 @@ -61,6 +61,21 @@ """ +OPTIMIZER_PYDANTIC_CFG = """ +[optimizer] +@optimizers = "Adam.pydantic.v1" +beta1 = 0.9 +beta2 = 0.999 +use_averages = true + +[optimizer.learn_rate] +@schedules = "warmup_linear.v1" +initial_rate = 0.1 +warmup_steps = 10000 +total_steps = 100000 +""" + + class HelloIntsSchema(BaseModel): hello: int world: int @@ -261,17 +276,19 @@ def test_read_config(): assert cfg["pipeline"]["classifier"]["model"]["embedding"]["width"] == 128 -def test_optimizer_config(): - cfg = Config().from_str(OPTIMIZER_CFG) +@pytest.mark.parametrize("optimizer_cfg_str", [OPTIMIZER_DATACLASS_CFG, OPTIMIZER_PYDANTIC_CFG]) +def test_optimizer_config(optimizer_cfg_str: str): + cfg = Config().from_str(optimizer_cfg_str) optimizer = my_registry.resolve(cfg, validate=True)["optimizer"] assert optimizer.beta1 == 0.9 -def test_config_to_str(): - cfg = Config().from_str(OPTIMIZER_CFG) - assert cfg.to_str().strip() == OPTIMIZER_CFG.strip() - cfg = Config({"optimizer": {"foo": "bar"}}).from_str(OPTIMIZER_CFG) - assert cfg.to_str().strip() == OPTIMIZER_CFG.strip() +@pytest.mark.parametrize("optimizer_cfg_str", [OPTIMIZER_DATACLASS_CFG, OPTIMIZER_PYDANTIC_CFG]) +def test_config_to_str(optimizer_cfg_str: str): + cfg = Config().from_str(optimizer_cfg_str) + assert cfg.to_str().strip() == optimizer_cfg_str.strip() + cfg = Config({"optimizer": {"foo": "bar"}}).from_str(optimizer_cfg_str) + assert cfg.to_str().strip() == optimizer_cfg_str.strip() def test_config_to_str_creates_intermediate_blocks(): @@ -287,28 +304,30 @@ def test_config_to_str_creates_intermediate_blocks(): ) -def test_config_roundtrip_bytes(): - cfg = Config().from_str(OPTIMIZER_CFG) +@pytest.mark.parametrize("optimizer_cfg_str", [OPTIMIZER_DATACLASS_CFG, OPTIMIZER_PYDANTIC_CFG]) +def test_config_roundtrip_bytes(optimizer_cfg_str: str): + cfg = Config().from_str(optimizer_cfg_str) cfg_bytes = cfg.to_bytes() new_cfg = Config().from_bytes(cfg_bytes) - assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip() + assert new_cfg.to_str().strip() == optimizer_cfg_str.strip() -def test_config_roundtrip_disk(): - cfg = Config().from_str(OPTIMIZER_CFG) +@pytest.mark.parametrize("optimizer_cfg_str", [OPTIMIZER_DATACLASS_CFG, OPTIMIZER_PYDANTIC_CFG]) +def test_config_roundtrip_disk(optimizer_cfg_str: str): + cfg = Config().from_str(optimizer_cfg_str) with make_tempdir() as path: cfg_path = path / "config.cfg" cfg.to_disk(cfg_path) new_cfg = Config().from_disk(cfg_path) - assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip() - + assert new_cfg.to_str().strip() == optimizer_cfg_str.strip() -def test_config_roundtrip_disk_respects_path_subclasses(pathy_fixture): - cfg = Config().from_str(OPTIMIZER_CFG) +@pytest.mark.parametrize("optimizer_cfg_str", [OPTIMIZER_DATACLASS_CFG, OPTIMIZER_PYDANTIC_CFG]) +def test_config_roundtrip_disk_respects_path_subclasses(pathy_fixture, optimizer_cfg_str: str): + cfg = Config().from_str(optimizer_cfg_str) cfg_path = pathy_fixture / "config.cfg" cfg.to_disk(cfg_path) new_cfg = Config().from_disk(cfg_path) - assert new_cfg.to_str().strip() == OPTIMIZER_CFG.strip() + assert new_cfg.to_str().strip() == optimizer_cfg_str.strip() def test_config_to_str_invalid_defaults(): diff --git a/confection/tests/util.py b/confection/tests/util.py index 112e998..9613154 100644 --- a/confection/tests/util.py +++ b/confection/tests/util.py @@ -10,11 +10,11 @@ Iterable, List, Union, - Generator, Generic, TypeVar, Optional, ) +from pydantic import BaseModel from pydantic.types import StrictBool import catalogue @@ -92,6 +92,30 @@ class Optimizer: ) +@my_registry.optimizers("Adam.pydantic.v1") +def Adam_pydantic( + learn_rate: FloatOrSeq = 0.001, + *, + beta1: FloatOrSeq = 0.001, + beta2: FloatOrSeq = 0.001, + use_averages: bool = True, +): + """ + Mocks optimizer generation. Note that the returned object is not actually an optimizer. This function is merely used + to illustrate how to use the function registry, e.g. with thinc. + """ + + class Optimizer(BaseModel): + learn_rate: FloatOrSeq + beta1: FloatOrSeq + beta2: FloatOrSeq + use_averages: bool + + return Optimizer( + learn_rate=learn_rate, beta1=beta1, beta2=beta2, use_averages=use_averages + ) + + @my_registry.schedules("warmup_linear.v1") def warmup_linear( initial_rate: float, warmup_steps: int, total_steps: int From beb567e1823d09fc9193e9f25ae4b5f56974189b Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 30 Jun 2023 17:18:49 -0700 Subject: [PATCH 14/45] fix conflict --- confection/tests/test_config.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index 6e21762..8b99d99 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -602,11 +602,7 @@ def test_schedule(): assert isinstance(result, Iterator) @my_registry.optimizers("test_optimizer.v2") -<<<<<<< HEAD def test_optimizer2(rate: Iterable[float]) -> Iterable[float]: -======= - def test_optimizer2(rate: Iterator) -> Iterator: ->>>>>>> ca99729924effddcd0b05c7a5471447a56b1a09d return rate cfg = { @@ -617,11 +613,7 @@ def test_optimizer2(rate: Iterator) -> Iterator: assert isinstance(result, Iterator) @my_registry.optimizers("test_optimizer.v3") -<<<<<<< HEAD def test_optimizer3(schedules: Dict[str, Iterable[float]]) -> Iterable[float]: -======= - def test_optimizer3(schedules: Dict[str, Iterator]) -> Iterator: ->>>>>>> ca99729924effddcd0b05c7a5471447a56b1a09d return schedules["rate"] cfg = { @@ -632,11 +624,7 @@ def test_optimizer3(schedules: Dict[str, Iterator]) -> Iterator: assert isinstance(result, Iterator) @my_registry.optimizers("test_optimizer.v4") -<<<<<<< HEAD def test_optimizer4(*schedules: Iterable[float]) -> Iterable[float]: -======= - def test_optimizer4(*schedules: Iterator) -> Iterator: ->>>>>>> ca99729924effddcd0b05c7a5471447a56b1a09d return schedules[0] From 07870e7f630762a92d54b78df90ca49ecada3a69 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 30 Jun 2023 17:20:31 -0700 Subject: [PATCH 15/45] move back to old Config nested class --- confection/__init__.py | 18 ++++++++---------- confection/tests/test_config.py | 14 ++++++++------ 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/confection/__init__.py b/confection/__init__.py index 9d1d7cb..9d465a5 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -713,17 +713,15 @@ def _safe_is_subclass(cls: type, expected: type) -> bool: class EmptySchema(BaseModel): - model_config = { - "extra": "allow", - "arbitrary_types_allowed": True - } + class Config: + extra = "allow" + arbitrary_types_allowed = True -_promise_schema_config = { - "extra": "forbid", - "arbitrary_types_allowed": True, - "alias_generator": alias_generator -} +class _PromiseSchemaConfig: + extra = "forbid" + arbitrary_types_allowed = True + alias_generator = alias_generator @dataclass @@ -1095,7 +1093,7 @@ def make_promise_schema( else: name = RESERVED_FIELDS.get(param.name, param.name) sig_args[name] = (annotation, default) - sig_args["__config__"] = _promise_schema_config + sig_args["__config__"] = _PromiseSchemaConfig return create_model("ArgModel", **sig_args) diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index 8b99d99..fe88ef7 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -80,14 +80,16 @@ class HelloIntsSchema(BaseModel): hello: int world: int - model_config = {"extra": "forbid"} + class Config: + extra = "forbid" class DefaultsSchema(BaseModel): required: int optional: str = "default value" - model_config = {"extra": "forbid"} + class Config: + extra = "forbid" class ComplexSchema(BaseModel): @@ -1223,9 +1225,8 @@ class TestSchemaContent(BaseModel): a: str b: int - model_config = { - "extra": "forbid", - } + class Config: + extra = "forbid" class TestSchema(BaseModel): cfg: TestSchemaContent @@ -1246,7 +1247,8 @@ class TestSchemaContent2(BaseModel): a: str b: int - model_config = {"extra": "allow"} + class Config: + extra = "allow" class TestSchema2(BaseModel): cfg: TestSchemaContent2 From 36bc36893d0136eeec69a39e36c417bc7b538b0e Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Sat, 1 Jul 2023 15:59:40 -0700 Subject: [PATCH 16/45] fix tests --- confection/__init__.py | 45 ++++++++++++++++++++++++--------- confection/tests/test_config.py | 21 ++++++++++----- 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/confection/__init__.py b/confection/__init__.py index 9d465a5..1e6ffbc 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -1,5 +1,5 @@ from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type, Mapping -from typing import Iterable, Sequence, Set, cast +from typing import Iterable, Sequence, Set, TypeVar, cast from dataclasses import dataclass, is_dataclass from configparser import ConfigParser, ExtendedInterpolation, MAX_INTERPOLATION_DEPTH from configparser import InterpolationMissingOptionError, InterpolationSyntaxError @@ -708,6 +708,25 @@ def get_field_annotation(field: FieldInfo) -> Type: return field.annotation if PYDANTIC_V2 else field.type_ +def get_model_fields(Schema: Union[Type[BaseModel], BaseModel]) -> Dict[str, FieldInfo]: + return Schema.model_fields if PYDANTIC_V2 else Schema.__fields__ + + +def get_model_fields_set(Schema: Union[Type[BaseModel], BaseModel]) -> Set[str]: + return Schema.model_fields_set if PYDANTIC_V2 else Schema.__fields_set__ + + +def get_model_extra(model: BaseModel) -> Dict[str, FieldInfo]: + return model.model_extra if PYDANTIC_V2 else {} + + +def set_model_field(Schema: Type[BaseModel], key: str, field: FieldInfo): + if PYDANTIC_V2: + Schema.model_fields[key] = field + else: + Schema.__fields__[key] = field + + def _safe_is_subclass(cls: type, expected: type) -> bool: return inspect.isclass(cls) and issubclass(cls, BaseModel) @@ -859,12 +878,13 @@ def _fill( value = overrides[key_parent] config[key] = value if cls.is_promise(value): - if key in schema.model_fields and not resolve: + model_fields = get_model_fields(schema) + if key in model_fields and not resolve: # If we're not resolving the config, make sure that the field # expecting the promise is typed Any so it doesn't fail # validation if it doesn't receive the function return value - field = schema.model_fields[key] - schema.model_fields[key] = copy_model_field(field, Any) + field = model_fields[key] + set_model_field(schema, key, copy_model_field(field, Any)) promise_schema = cls.make_promise_schema(value, resolve=resolve) filled[key], validation[v_key], final[key] = cls._fill( value, @@ -903,11 +923,12 @@ def _fill( # validation[v_key] = [] elif hasattr(value, "items"): field_type = EmptySchema - fields = schema.model_fields if PYDANTIC_V2 else schema.__fields__ + fields = get_model_fields(schema) if key in fields: field = fields[key] - if field.annotation is not None and _safe_is_subclass(field.annotation, BaseModel): - field_type = field.annotation + annotation = get_field_annotation(field) + if annotation is not None and _safe_is_subclass(annotation, BaseModel): + field_type = annotation filled[key], validation[v_key], final[key] = cls._fill( value, field_type, @@ -932,7 +953,7 @@ def _fill( exclude = [] if validate: try: - result = schema.model_validate(validation) if PYDANTIC_V2 else schema(**validation) + result = model_validate(schema, validation) except ValidationError as e: raise ConfigValidationError( @@ -940,12 +961,12 @@ def _fill( ) from None else: # Same as model_validate, but without validation - fields_set = set(schema.model_fields.keys()) + fields_set = set(get_model_fields(schema).keys()) result = model_construct(schema, fields_set, validation) # If our schema doesn't allow extra values, we need to filter them # manually because .construct doesn't parse anything if get_model_config_extra(schema) in ("forbid", "extra"): - fields = result.model_fields_set + fields = get_model_fields_set(result) exclude = [k for k in dict(result).keys() if k not in fields] exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()]) # Do a shallow serialization first @@ -955,8 +976,8 @@ def _fill( # model.dict() # Allows for returning Pydantic models from a registered function shallow_result_dict = dict(result) - if result.model_extra is not None: - shallow_result_dict.update(result.model_extra) + # if result.model_extra is not None: + # shallow_result_dict.update(result.model_extra) result_dict = {} for k, v in shallow_result_dict.items(): if k in exclude_validation: diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index fe88ef7..660afcb 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -10,7 +10,7 @@ from pydantic.fields import Field from pydantic.types import StrictBool -from confection import ConfigValidationError, Config +from confection import ConfigValidationError, Config, get_model_fields, PYDANTIC_V2 from confection.util import Generator, partial from confection.tests.util import Cat, my_registry, make_tempdir @@ -76,6 +76,12 @@ """ +if PYDANTIC_V2: + INT_PARSING_ERROR_TYPE = "int_parsing" +else: + INT_PARSING_ERROR_TYPE = "type_error.integer" + + class HelloIntsSchema(BaseModel): hello: int world: int @@ -119,7 +125,7 @@ def test_invalidate_simple_config(): my_registry._fill(invalid_config, HelloIntsSchema) error = exc_info.value assert len(error.errors) == 1 - assert "int_parsing" in error.error_types + assert INT_PARSING_ERROR_TYPE in error.error_types def test_invalidate_extra_args(): @@ -169,8 +175,9 @@ def test_parse_args(): def test_make_promise_schema(): schema = my_registry.make_promise_schema(good_catsie) - assert "evil" in schema.model_fields - assert "cute" in schema.model_fields + model_fields = get_model_fields(schema) + assert "evil" in model_fields + assert "cute" in model_fields def test_validate_promise(): @@ -1272,9 +1279,9 @@ class Schema(BaseModel): assert e1.show_config is True assert len(e1.errors) == 1 assert e1.errors[0]["loc"] == ("world",) - assert e1.errors[0]["msg"] == "Input should be a valid integer, unable to parse string as an integer" - assert e1.errors[0]["type"] == "int_parsing" - assert e1.error_types == {"int_parsing"} + assert e1.errors[0]["type"] == INT_PARSING_ERROR_TYPE + assert e1.error_types == {INT_PARSING_ERROR_TYPE} + # Create a new error with overrides title = "Custom error" desc = "Some error description here" From 0b31287d3b9599df5b4cf9e8ede068985efb93ca Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 7 Jul 2023 12:25:17 -0700 Subject: [PATCH 17/45] update from model_extra --- confection/__init__.py | 18 ++++++++++++------ confection/tests/test_config.py | 12 +++++++++--- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/confection/__init__.py b/confection/__init__.py index 1e6ffbc..a96f4f5 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -11,6 +11,7 @@ from pydantic.version import VERSION as PYDANTIC_VERSION import srsly import catalogue +from types import GeneratorType import inspect import io import copy @@ -727,6 +728,12 @@ def set_model_field(Schema: Type[BaseModel], key: str, field: FieldInfo): Schema.__fields__[key] = field +def update_from_model_extra(shallow_result_dict: Dict[str, Any], result: BaseModel) -> None: + if PYDANTIC_V2: + if result.model_extra is not None: + shallow_result_dict.update(result.model_extra) + + def _safe_is_subclass(cls: type, expected: type) -> bool: return inspect.isclass(cls) and issubclass(cls, BaseModel) @@ -917,10 +924,10 @@ def _fill( validation[v_key] = getter_result final[key] = getter_result # if isinstance(validation[v_key], GeneratorType): - # # If value is a generator we can't validate type without - # # consuming it (which doesn't work if it's infinite – see - # # schedule for examples). So we skip it. - # validation[v_key] = [] + # If value is a generator we can't validate type without + # consuming it (which doesn't work if it's infinite – see + # schedule for examples). So we skip it. + # validation[v_key] = [] elif hasattr(value, "items"): field_type = EmptySchema fields = get_model_fields(schema) @@ -976,8 +983,7 @@ def _fill( # model.dict() # Allows for returning Pydantic models from a registered function shallow_result_dict = dict(result) - # if result.model_extra is not None: - # shallow_result_dict.update(result.model_extra) + update_from_model_extra(shallow_result_dict, result) result_dict = {} for k, v in shallow_result_dict.items(): if k in exclude_validation: diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index 660afcb..fe805a8 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -285,7 +285,7 @@ def test_read_config(): assert cfg["pipeline"]["classifier"]["model"]["embedding"]["width"] == 128 -@pytest.mark.parametrize("optimizer_cfg_str", [OPTIMIZER_DATACLASS_CFG, OPTIMIZER_PYDANTIC_CFG]) +@pytest.mark.parametrize("optimizer_cfg_str", [OPTIMIZER_DATACLASS_CFG, OPTIMIZER_PYDANTIC_CFG], ids=["dataclasses", "pydantic"]) def test_optimizer_config(optimizer_cfg_str: str): cfg = Config().from_str(optimizer_cfg_str) optimizer = my_registry.resolve(cfg, validate=True)["optimizer"] @@ -353,10 +353,16 @@ def test_config_to_str_invalid_defaults(): def test_validation_custom_types(): + if PYDANTIC_V2: + log_field = Field("ERROR", pattern="(DEBUG|INFO|WARNING|ERROR)") + else: + log_field = Field("ERROR", regex="(DEBUG|INFO|WARNING|ERROR)") + + def complex_args( rate: StrictFloat, - steps: PositiveInt = 10, # type: ignore - log_level: str = Field("ERROR", pattern="(DEBUG|INFO|WARNING|ERROR)"), + steps: PositiveInt = 10, + log_level: str = log_field, ): return None From 2247265dfe1ead7cb6bf779dc18812cd830ede15 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 7 Jul 2023 12:54:34 -0700 Subject: [PATCH 18/45] fix pydantic generator equals --- confection/__init__.py | 17 +++++++++-------- confection/tests/test_config.py | 5 +++-- confection/util.py | 32 +++++++++++++++++++------------- 3 files changed, 31 insertions(+), 23 deletions(-) diff --git a/confection/__init__.py b/confection/__init__.py index a96f4f5..3734586 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -8,7 +8,6 @@ from pathlib import Path from pydantic import BaseModel, create_model, ValidationError, Extra from pydantic.fields import FieldInfo -from pydantic.version import VERSION as PYDANTIC_VERSION import srsly import catalogue from types import GeneratorType @@ -18,7 +17,7 @@ import re import warnings -from .util import Decorator, SimpleFrozenDict, SimpleFrozenList +from .util import Decorator, SimpleFrozenDict, SimpleFrozenList, PYDANTIC_V2 # Field used for positional arguments, e.g. [section.*.xyz]. The alias is # required for the schema (shouldn't clash with user-defined arg names) @@ -35,7 +34,6 @@ # Regex to detect whether a value contains a variable VARIABLE_RE = re.compile(r"\$\{[\w\.:]+\}") -PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") class CustomInterpolation(ExtendedInterpolation): @@ -923,11 +921,11 @@ def _fill( ) validation[v_key] = getter_result final[key] = getter_result - # if isinstance(validation[v_key], GeneratorType): + if isinstance(validation[v_key], GeneratorType): # If value is a generator we can't validate type without # consuming it (which doesn't work if it's infinite – see # schedule for examples). So we skip it. - # validation[v_key] = [] + validation[v_key] = [] elif hasattr(value, "items"): field_type = EmptySchema fields = get_model_fields(schema) @@ -952,8 +950,9 @@ def _fill( else: filled[key] = value # Prevent pydantic from consuming generator if part of a union - # TODO: reset for v1 pydantic - validation[v_key] = value + validation[v_key] = ( + value if not isinstance(value, GeneratorType) else [] + ) final[key] = value # Now that we've filled in all of the promises, update with defaults # from schema, and validate if validation is enabled @@ -1025,9 +1024,11 @@ def _update_from_parsed( # Check numpy first, just in case. Use stringified type so that numpy dependency can be ditched. elif str(type(value)) == "": final[key] = value + elif isinstance(value, BaseModel) and isinstance(final[key], BaseModel): + final[key] = value elif ( value != final[key] or not isinstance(type(value), type(final[key])) - ): + ) and not isinstance(final[key], GeneratorType): final[key] = value return filled, final diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index fe805a8..61c4fa5 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -255,14 +255,15 @@ class TestSchema(BaseModel): my_registry.resolve({"cfg": config}, schema=TestSchema) -@pytest.mark.skip("In Pydantic v2, int/float cannot be coerced to str so this test will fail.") +@pytest.mark.skipif(PYDANTIC_V2, reason="In Pydantic v2, int/float cannot be coerced to str so this test will fail.") def test_resolve_schema_coerced(): class TestBaseSchema(BaseModel): test1: str test2: bool test3: float - model_config = {"strict": False} + class Config: + strict = False class TestSchema(BaseModel): cfg: TestBaseSchema diff --git a/confection/util.py b/confection/util.py index d204118..8139218 100644 --- a/confection/util.py +++ b/confection/util.py @@ -1,6 +1,8 @@ import functools import sys from typing import Any, Callable, Iterator, TypeVar +from pydantic.version import VERSION as PYDANTIC_VERSION + if sys.version_info < (3, 8): # Ignoring type for mypy to avoid "Incompatible import" error (https://github.com/python/mypy/issues/4427). @@ -9,6 +11,7 @@ from typing import Protocol _DIn = TypeVar("_DIn") +PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") class Decorator(Protocol): @@ -33,21 +36,24 @@ def partial( return partial_func -class Generator(Iterator): - """Custom generator type. Used to annotate function arguments that accept - generators so they can be validated by pydantic (which doesn't support - iterators/iterables otherwise). - """ +if PYDANTIC_V2: + Generator = Iterator +else: + class Generator(Iterator): + """Custom generator type. Used to annotate function arguments that accept + generators so they can be validated by pydantic (which doesn't support + iterators/iterables otherwise). + """ - @classmethod - def __get_validators__(cls): - yield cls.validate + @classmethod + def __get_validators__(cls): + yield cls.validate - @classmethod - def validate(cls, v): - if not hasattr(v, "__iter__") and not hasattr(v, "__next__"): - raise TypeError("not a valid iterator") - return v + @classmethod + def validate(cls, v): + if not hasattr(v, "__iter__") and not hasattr(v, "__next__"): + raise TypeError("not a valid iterator") + return v DEFAULT_FROZEN_DICT_ERROR = ( From 712e0edb854e03f074ecf2ea558b8fe74775553b Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 7 Jul 2023 13:00:02 -0700 Subject: [PATCH 19/45] fixes for organization --- confection/__init__.py | 9 ++++----- confection/util.py | 29 +++++++++++++---------------- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/confection/__init__.py b/confection/__init__.py index 3734586..28749c7 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -1,16 +1,16 @@ from typing import Union, Dict, Any, Optional, List, Tuple, Callable, Type, Mapping -from typing import Iterable, Sequence, Set, TypeVar, cast +from typing import Iterable, Sequence, Set, TypeVar, TYPE_CHECKING, cast +from types import GeneratorType from dataclasses import dataclass, is_dataclass from configparser import ConfigParser, ExtendedInterpolation, MAX_INTERPOLATION_DEPTH from configparser import InterpolationMissingOptionError, InterpolationSyntaxError from configparser import NoSectionError, NoOptionError, InterpolationDepthError from configparser import ParsingError from pathlib import Path -from pydantic import BaseModel, create_model, ValidationError, Extra +from pydantic import BaseModel, create_model, ValidationError from pydantic.fields import FieldInfo import srsly import catalogue -from types import GeneratorType import inspect import io import copy @@ -35,7 +35,6 @@ VARIABLE_RE = re.compile(r"\$\{[\w\.:]+\}") - class CustomInterpolation(ExtendedInterpolation): def before_read(self, parser, section, option, value): # If we're dealing with a quoted string as the interpolation value, @@ -666,7 +665,7 @@ def alias_generator(name: str) -> str: return name -def copy_model_field(field: Union["FieldInfo", "ModelField"], type_: Type) -> Union["FieldInfo", "ModelField"]: +def copy_model_field(field: FieldInfo, type_: Type) -> FieldInfo: """Copy a model field and assign a new type, e.g. to accept an Any type even though the original value is typed differently. """ diff --git a/confection/util.py b/confection/util.py index 8139218..ed08f06 100644 --- a/confection/util.py +++ b/confection/util.py @@ -36,24 +36,21 @@ def partial( return partial_func -if PYDANTIC_V2: - Generator = Iterator -else: - class Generator(Iterator): - """Custom generator type. Used to annotate function arguments that accept - generators so they can be validated by pydantic (which doesn't support - iterators/iterables otherwise). - """ +class Generator(Iterator): + """Custom generator type. Used to annotate function arguments that accept + generators so they can be validated by pydantic (which doesn't support + iterators/iterables otherwise). + """ - @classmethod - def __get_validators__(cls): - yield cls.validate + @classmethod + def __get_validators__(cls): + yield cls.validate - @classmethod - def validate(cls, v): - if not hasattr(v, "__iter__") and not hasattr(v, "__next__"): - raise TypeError("not a valid iterator") - return v + @classmethod + def validate(cls, v): + if not hasattr(v, "__iter__") and not hasattr(v, "__next__"): + raise TypeError("not a valid iterator") + return v DEFAULT_FROZEN_DICT_ERROR = ( From ee6b10c92918e0714a2101fc9fb71cc9100ec23a Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 7 Jul 2023 13:05:32 -0700 Subject: [PATCH 20/45] allow pydantic v1/v2 in reqs/setup and test both in CI --- .github/workflows/tests.yml | 20 +++++++++++++++----- requirements.txt | 4 ++-- setup.cfg | 4 ++-- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0201b21..352b6e1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -33,8 +33,8 @@ jobs: python -m pip install -U -r requirements.txt python -m build --sdist - - name: Run mypy - run: python -m mypy confection + # - name: Run mypy + # run: python -m mypy confection - name: Delete source directory run: rm -rf confection @@ -54,10 +54,20 @@ jobs: - name: Test import run: python -c "import confection" -Werror - - name: Install test requirements - run: python -m pip install -U -r requirements.txt + - name: Install test requirements with Pydantic v1 + run: | + python -m pip install -U -r requirements.txt + python -m pip install -U pydantic==1.10.* + + - name: Run tests for Pydantic v1 + run: python -m pytest --pyargs confection -Werror + + - name: Install test requirements with Pydantic v2 + run: | + python -m pip install -U -r requirements.txt + python -m pip install -U pydantic - - name: Run tests + - name: Run tests for Pydantic v2 run: python -m pytest --pyargs confection -Werror - name: Test for import conflicts with hypothesis diff --git a/requirements.txt b/requirements.txt index d54da9b..0cb4ae3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -pydantic>=2.0,<2.1 -typing_extensions>=4.6.1,<5.0; python_version < "3.8" +pydantic>=1.7.4,!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,<3.0.0 +typing_extensions>=4.5.0; python_version < "3.8" srsly>=2.4.0,<3.0.0 # Development requirements pathy>=0.3.5 diff --git a/setup.cfg b/setup.cfg index 0849d1b..a4e4895 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,8 +30,8 @@ zip_safe = true include_package_data = true python_requires = >=3.6 install_requires = - pydantic>=2.0,<2.1 - typing_extensions>=4.6.1,<5.0; python_version < "3.8" + pydantic>=1.7.4,!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,<3.0.0 + typing_extensions>=4.5.0; python_version < "3.8" srsly>=2.4.0,<3.0.0 [sdist] From 81fa915fed9582d1b17e210e15dad80aefaaefe5 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 7 Jul 2023 13:07:18 -0700 Subject: [PATCH 21/45] only run CI push to main, not other branches --- .github/workflows/tests.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 352b6e1..3d07073 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -2,6 +2,8 @@ name: tests on: push: + branches: + - main pull_request: types: [opened, synchronize, reopened, edited] From 0fb6858f1138a1d322e1001c08c3a7b26b11d672 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 7 Jul 2023 13:20:52 -0700 Subject: [PATCH 22/45] fix issue with model_validate --- confection/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/confection/__init__.py b/confection/__init__.py index 28749c7..83c2641 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -691,7 +691,7 @@ def get_model_config_extra(model: Type[BaseModel]) -> str: def model_validate(Schema: Type[_ModelT], data: Dict[str, Any]) -> _ModelT: - return Schema.model_validate(**data) if PYDANTIC_V2 else Schema(**data) + return Schema.model_validate(data) if PYDANTIC_V2 else Schema(**data) def model_construct(Schema: Type[_ModelT], fields_set: Optional[Set[str]], data: Dict[str, Any]) -> _ModelT: From 04354f26769d3504e42ae445cb0be2e4014d1701 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 7 Jul 2023 13:21:21 -0700 Subject: [PATCH 23/45] fix filter warnings for tests --- .github/workflows/tests.yml | 4 ++-- pyproject.toml | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3d07073..c76c8f0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -62,7 +62,7 @@ jobs: python -m pip install -U pydantic==1.10.* - name: Run tests for Pydantic v1 - run: python -m pytest --pyargs confection -Werror + run: python -m pytest confection - name: Install test requirements with Pydantic v2 run: | @@ -70,7 +70,7 @@ jobs: python -m pip install -U pydantic - name: Run tests for Pydantic v2 - run: python -m pytest --pyargs confection -Werror + run: python -m pytest confection - name: Test for import conflicts with hypothesis run: | diff --git a/pyproject.toml b/pyproject.toml index 40810cc..085b725 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,3 +3,9 @@ requires = [ "setuptools", ] build-backend = "setuptools.build_meta" + +[tool.pytest.ini_options] +filterwarnings = [ + "error", + "ignore:^.*Support for class-based `config` is deprecated, use ConfigDict instead.*:DeprecationWarning" +] From a5f2d5a6d3e4f239b6eb73d9a724613d54328e5b Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 7 Jul 2023 13:27:37 -0700 Subject: [PATCH 24/45] try run ci --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c76c8f0..ab7c90a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -62,7 +62,7 @@ jobs: python -m pip install -U pydantic==1.10.* - name: Run tests for Pydantic v1 - run: python -m pytest confection + run: python -m pytest - name: Install test requirements with Pydantic v2 run: | @@ -70,7 +70,7 @@ jobs: python -m pip install -U pydantic - name: Run tests for Pydantic v2 - run: python -m pytest confection + run: python -m pytest - name: Test for import conflicts with hypothesis run: | From 21196e5b92ae4370d3670a8e44b227b2f9b523fe Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 7 Jul 2023 13:31:51 -0700 Subject: [PATCH 25/45] try run ci --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ab7c90a..c90b582 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -62,7 +62,7 @@ jobs: python -m pip install -U pydantic==1.10.* - name: Run tests for Pydantic v1 - run: python -m pytest + run: python -m pytest --pyargs confection - name: Install test requirements with Pydantic v2 run: | @@ -70,7 +70,7 @@ jobs: python -m pip install -U pydantic - name: Run tests for Pydantic v2 - run: python -m pytest + run: python -m pytest --pyargs confection - name: Test for import conflicts with hypothesis run: | From 20974fd4e35a34d056cda01e0d5c9bf9f21733f1 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 7 Jul 2023 13:33:07 -0700 Subject: [PATCH 26/45] smaller test matrix --- .github/workflows/tests.yml | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c90b582..a761d4a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,9 +13,21 @@ jobs: if: github.repository == 'explosion/confection' strategy: fail-fast: false + # matrix: + # os: [ubuntu-latest, windows-latest, macos-latest] + # python_version: ["3.7", "3.8", "3.9", "3.10", "3.11"] matrix: - os: [ubuntu-latest, windows-latest, macos-latest] - python_version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + include: + - python_version: '3.7' + os: ubuntu-latest + - python_version: '3.8' + os: windows-latest + - python_version: '3.11' + os: ubuntu-latest + - python_version: '3.11' + os: windows-latest + - python_version: '3.11' + os: macos-latest runs-on: ${{ matrix.os }} From 995b9af61afde1e6a0602114d2688a03d47d0963 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 7 Jul 2023 13:37:13 -0700 Subject: [PATCH 27/45] print pydantic version before tests --- .github/workflows/tests.yml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a761d4a..7fe9910 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -74,7 +74,9 @@ jobs: python -m pip install -U pydantic==1.10.* - name: Run tests for Pydantic v1 - run: python -m pytest --pyargs confection + run: | + python -c "import pydantic; print(pydantic.VERSION)" + python -m pytest --pyargs confection - name: Install test requirements with Pydantic v2 run: | @@ -82,7 +84,9 @@ jobs: python -m pip install -U pydantic - name: Run tests for Pydantic v2 - run: python -m pytest --pyargs confection + run: | + python -c "import pydantic; print(pydantic.VERSION)" + python -m pytest --pyargs confection - name: Test for import conflicts with hypothesis run: | From 3acfe90aa2185c66a5513dd29869cf01190e0e12 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 7 Jul 2023 13:43:53 -0700 Subject: [PATCH 28/45] fixes for mypy --- confection/__init__.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/confection/__init__.py b/confection/__init__.py index 83c2641..e275629 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -665,7 +665,7 @@ def alias_generator(name: str) -> str: return name -def copy_model_field(field: FieldInfo, type_: Type) -> FieldInfo: +def copy_model_field(field: FieldInfo, type_: Any) -> FieldInfo: """Copy a model field and assign a new type, e.g. to accept an Any type even though the original value is typed differently. """ @@ -673,13 +673,13 @@ def copy_model_field(field: FieldInfo, type_: Type) -> FieldInfo: if PYDANTIC_V2: field_info.annotation = type_ else: - field_info.type_ = type_ + field_info.type_ = type_ # type: ignore return field_info def get_model_config_extra(model: Type[BaseModel]) -> str: if PYDANTIC_V2: - extra = model.model_config.get("extra", "forbid") + extra = str(model.model_config.get("extra", "forbid")) else: extra = str(model.Config.extra) or "forbid" assert isinstance(extra, str) @@ -703,26 +703,26 @@ def model_dump(instance: BaseModel) -> Dict[str, Any]: def get_field_annotation(field: FieldInfo) -> Type: - return field.annotation if PYDANTIC_V2 else field.type_ + return field.annotation if PYDANTIC_V2 else field.type_ # type: ignore def get_model_fields(Schema: Union[Type[BaseModel], BaseModel]) -> Dict[str, FieldInfo]: - return Schema.model_fields if PYDANTIC_V2 else Schema.__fields__ + return Schema.model_fields if PYDANTIC_V2 else Schema.__fields__ # type: ignore def get_model_fields_set(Schema: Union[Type[BaseModel], BaseModel]) -> Set[str]: - return Schema.model_fields_set if PYDANTIC_V2 else Schema.__fields_set__ + return Schema.model_fields_set if PYDANTIC_V2 else Schema.__fields_set__ # type: ignore def get_model_extra(model: BaseModel) -> Dict[str, FieldInfo]: - return model.model_extra if PYDANTIC_V2 else {} + return model.model_extra or {} if PYDANTIC_V2 else {} def set_model_field(Schema: Type[BaseModel], key: str, field: FieldInfo): if PYDANTIC_V2: Schema.model_fields[key] = field else: - Schema.__fields__[key] = field + Schema.__fields__[key] = field # type: ignore def update_from_model_extra(shallow_result_dict: Dict[str, Any], result: BaseModel) -> None: @@ -888,7 +888,8 @@ def _fill( # expecting the promise is typed Any so it doesn't fail # validation if it doesn't receive the function return value field = model_fields[key] - set_model_field(schema, key, copy_model_field(field, Any)) + new_field = copy_model_field(field, Any) + set_model_field(schema, key, new_field) promise_schema = cls.make_promise_schema(value, resolve=resolve) filled[key], validation[v_key], final[key] = cls._fill( value, @@ -971,8 +972,8 @@ def _fill( # If our schema doesn't allow extra values, we need to filter them # manually because .construct doesn't parse anything if get_model_config_extra(schema) in ("forbid", "extra"): - fields = get_model_fields_set(result) - exclude = [k for k in dict(result).keys() if k not in fields] + result_field_names = get_model_fields_set(result) + exclude = [k for k in dict(result).keys() if k not in result_field_names] exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()]) # Do a shallow serialization first # If any of the sub-objects are Pydantic models, first check if they From 0bac230c6601f41c466dd48224e73c4ce6c5fb01 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 7 Jul 2023 13:45:31 -0700 Subject: [PATCH 29/45] test fixes --- confection/tests/test_2.py | 36 --------------------------------- confection/tests/test_config.py | 2 +- 2 files changed, 1 insertion(+), 37 deletions(-) delete mode 100644 confection/tests/test_2.py diff --git a/confection/tests/test_2.py b/confection/tests/test_2.py deleted file mode 100644 index 290ecd7..0000000 --- a/confection/tests/test_2.py +++ /dev/null @@ -1,36 +0,0 @@ -import dataclasses -from typing import Union, Iterable -import catalogue -from confection import registry, Config -from pydantic import BaseModel - -# Create a new registry. -registry.optimizers = catalogue.create("confection", "optimizers", entry_points=False) - - -# Define a dummy optimizer class. - -@dataclasses.dataclass -class MyCoolOptimizer: - learn_rate: float - gamma: float - - -@registry.optimizers.register("my_cool_optimizer.v1") -def make_my_optimizer(learn_rate: Union[float, Iterable[float]], gamma: float): - return MyCoolOptimizer(learn_rate=learn_rate, gamma=gamma) - - -if __name__ == "__main__": - # Load the config file from disk, resolve it and fetch the instantiated optimizer object. - cfg_str = """ -[optimizer] -@optimizers = "my_cool_optimizer.v1" -learn_rate = 0.001 -gamma = 1e-8 - """ - config = Config().from_str(cfg_str) - resolved = registry.resolve(config) - optimizer = resolved["optimizer"] # MyCoolOptimizer(learn_rate=0.001, gamma=1e-08) - - print(config, resolved, optimizer) diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index 61c4fa5..e53f360 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -11,7 +11,7 @@ from pydantic.types import StrictBool from confection import ConfigValidationError, Config, get_model_fields, PYDANTIC_V2 -from confection.util import Generator, partial +from confection.util import partial from confection.tests.util import Cat, my_registry, make_tempdir From 6d95b503e9782e3777731575918f5c3355d08ad2 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Fri, 7 Jul 2023 13:49:12 -0700 Subject: [PATCH 30/45] re-enable mypy --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7fe9910..0bda763 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -47,8 +47,8 @@ jobs: python -m pip install -U -r requirements.txt python -m build --sdist - # - name: Run mypy - # run: python -m mypy confection + - name: Run mypy + run: python -m mypy confection/__init__.py confection/util.py - name: Delete source directory run: rm -rf confection From aa7d13bb6cd2f1c6c06c4419acf01d520b3869ff Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Thu, 3 Aug 2023 15:25:20 +0200 Subject: [PATCH 31/45] Undo unrelated changes to CI tests --- .github/workflows/tests.yml | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0bda763..e1aaafb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -2,8 +2,6 @@ name: tests on: push: - branches: - - main pull_request: types: [opened, synchronize, reopened, edited] @@ -13,21 +11,14 @@ jobs: if: github.repository == 'explosion/confection' strategy: fail-fast: false - # matrix: - # os: [ubuntu-latest, windows-latest, macos-latest] - # python_version: ["3.7", "3.8", "3.9", "3.10", "3.11"] matrix: + os: [ubuntu-latest, windows-latest, macos-latest] + python_version: ["3.7", "3.8", "3.9", "3.10", "3.11"] include: - - python_version: '3.7' - os: ubuntu-latest - - python_version: '3.8' - os: windows-latest - - python_version: '3.11' - os: ubuntu-latest - - python_version: '3.11' - os: windows-latest - - python_version: '3.11' - os: macos-latest + - os: windows-2019 + python_version: "3.6" + - os: ubuntu-20.04 + python_version: "3.6" runs-on: ${{ matrix.os }} @@ -48,7 +39,8 @@ jobs: python -m build --sdist - name: Run mypy - run: python -m mypy confection/__init__.py confection/util.py + run: python -m mypy confection + if: matrix.python_version != '3.6' - name: Delete source directory run: rm -rf confection From fc29ccd47c7745b43581000e39daef7af9478e03 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Thu, 3 Aug 2023 15:28:18 +0200 Subject: [PATCH 32/45] Ignore tests for mypy --- confection/tests/test_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index e53f360..34c1864 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -1,3 +1,4 @@ +# type: ignore import inspect import platform From 65e69c1450a2493445a41ebb59ddb8e64899dd6a Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Thu, 3 Aug 2023 15:33:11 +0200 Subject: [PATCH 33/45] Add mypy for pydantic v1 --- .github/workflows/tests.yml | 11 +++++++++-- confection/__init__.py | 20 ++++++++++---------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e1aaafb..5880f6c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -38,8 +38,15 @@ jobs: python -m pip install -U -r requirements.txt python -m build --sdist - - name: Run mypy - run: python -m mypy confection + - name: Run mypy for Pydantic v2 + run: | + python -m mypy confection + if: matrix.python_version != '3.6' + + - name: Run mypy for Pydantic v1 + run: | + python -m pip install -U pydantic==1.10.* + python -m mypy confection if: matrix.python_version != '3.6' - name: Delete source directory diff --git a/confection/__init__.py b/confection/__init__.py index e275629..6e7f403 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -671,7 +671,7 @@ def copy_model_field(field: FieldInfo, type_: Any) -> FieldInfo: """ field_info = copy.deepcopy(field) if PYDANTIC_V2: - field_info.annotation = type_ + field_info.annotation = type_ # type: ignore else: field_info.type_ = type_ # type: ignore return field_info @@ -679,9 +679,9 @@ def copy_model_field(field: FieldInfo, type_: Any) -> FieldInfo: def get_model_config_extra(model: Type[BaseModel]) -> str: if PYDANTIC_V2: - extra = str(model.model_config.get("extra", "forbid")) + extra = str(model.model_config.get("extra", "forbid")) # type: ignore else: - extra = str(model.Config.extra) or "forbid" + extra = str(model.Config.extra) or "forbid" # type: ignore assert isinstance(extra, str) return extra @@ -691,15 +691,15 @@ def get_model_config_extra(model: Type[BaseModel]) -> str: def model_validate(Schema: Type[_ModelT], data: Dict[str, Any]) -> _ModelT: - return Schema.model_validate(data) if PYDANTIC_V2 else Schema(**data) + return Schema.model_validate(data) if PYDANTIC_V2 else Schema(**data) # type: ignore def model_construct(Schema: Type[_ModelT], fields_set: Optional[Set[str]], data: Dict[str, Any]) -> _ModelT: - return Schema.model_construct(fields_set, **data) if PYDANTIC_V2 else Schema.construct(fields_set, **data) + return Schema.model_construct(fields_set, **data) if PYDANTIC_V2 else Schema.construct(fields_set, **data) # type: ignore def model_dump(instance: BaseModel) -> Dict[str, Any]: - return instance.model_dump() if PYDANTIC_V2 else instance.dict() + return instance.model_dump() if PYDANTIC_V2 else instance.dict() # type: ignore def get_field_annotation(field: FieldInfo) -> Type: @@ -715,20 +715,20 @@ def get_model_fields_set(Schema: Union[Type[BaseModel], BaseModel]) -> Set[str]: def get_model_extra(model: BaseModel) -> Dict[str, FieldInfo]: - return model.model_extra or {} if PYDANTIC_V2 else {} + return model.model_extra or {} if PYDANTIC_V2 else {} # type: ignore def set_model_field(Schema: Type[BaseModel], key: str, field: FieldInfo): if PYDANTIC_V2: - Schema.model_fields[key] = field + Schema.model_fields[key] = field # type: ignore else: Schema.__fields__[key] = field # type: ignore def update_from_model_extra(shallow_result_dict: Dict[str, Any], result: BaseModel) -> None: if PYDANTIC_V2: - if result.model_extra is not None: - shallow_result_dict.update(result.model_extra) + if result.model_extra is not None: # type: ignore + shallow_result_dict.update(result.model_extra) # type: ignore def _safe_is_subclass(cls: type, expected: type) -> bool: From a9dd2a3f9da214959c372291bb812d58f491522a Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Thu, 3 Aug 2023 15:35:13 +0200 Subject: [PATCH 34/45] Format --- confection/__init__.py | 31 +++++++++++++++------------ confection/tests/test_config.py | 37 +++++++++++++++++++++++++-------- 2 files changed, 46 insertions(+), 22 deletions(-) diff --git a/confection/__init__.py b/confection/__init__.py index 6e7f403..3ca1f24 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -411,11 +411,7 @@ def to_str(self, *, interpolate: bool = True) -> str: if hasattr(value, "items"): # Reference to a function with no arguments, serialize # inline as a dict and don't create new section - if ( - registry.is_promise(value) - and len(value) == 1 - and is_kwarg - ): + if registry.is_promise(value) and len(value) == 1 and is_kwarg: flattened.set(section_name, key, try_dump_json(value, node)) else: queue.append((path + (key,), value)) @@ -686,7 +682,6 @@ def get_model_config_extra(model: Type[BaseModel]) -> str: return extra - _ModelT = TypeVar("_ModelT", bound=BaseModel) @@ -694,7 +689,9 @@ def model_validate(Schema: Type[_ModelT], data: Dict[str, Any]) -> _ModelT: return Schema.model_validate(data) if PYDANTIC_V2 else Schema(**data) # type: ignore -def model_construct(Schema: Type[_ModelT], fields_set: Optional[Set[str]], data: Dict[str, Any]) -> _ModelT: +def model_construct( + Schema: Type[_ModelT], fields_set: Optional[Set[str]], data: Dict[str, Any] +) -> _ModelT: return Schema.model_construct(fields_set, **data) if PYDANTIC_V2 else Schema.construct(fields_set, **data) # type: ignore @@ -725,7 +722,9 @@ def set_model_field(Schema: Type[BaseModel], key: str, field: FieldInfo): Schema.__fields__[key] = field # type: ignore -def update_from_model_extra(shallow_result_dict: Dict[str, Any], result: BaseModel) -> None: +def update_from_model_extra( + shallow_result_dict: Dict[str, Any], result: BaseModel +) -> None: if PYDANTIC_V2: if result.model_extra is not None: # type: ignore shallow_result_dict.update(result.model_extra) # type: ignore @@ -860,7 +859,7 @@ def _fill( resolve: bool = True, parent: str = "", overrides: Dict[str, Dict[str, Any]] = {}, - resolved_object_keys: Set[str] = set() + resolved_object_keys: Set[str] = set(), ) -> Tuple[ Union[Dict[str, Any], Config], Union[Dict[str, Any], Config], Dict[str, Any] ]: @@ -898,7 +897,7 @@ def _fill( resolve=resolve, parent=key_parent, overrides=overrides, - resolved_object_keys=resolved_object_keys + resolved_object_keys=resolved_object_keys, ) reg_name, func_name = cls.get_constructor(final[key]) args, kwargs = cls.parse_args(final[key]) @@ -911,7 +910,9 @@ def _fill( # here, because we want the traceback if the function fails. getter_result = getter(*args, **kwargs) - if isinstance(getter_result, BaseModel) or is_dataclass(getter_result): + if isinstance(getter_result, BaseModel) or is_dataclass( + getter_result + ): resolved_object_keys.add(key) else: # We're not resolving and calling the function, so replace @@ -932,7 +933,9 @@ def _fill( if key in fields: field = fields[key] annotation = get_field_annotation(field) - if annotation is not None and _safe_is_subclass(annotation, BaseModel): + if annotation is not None and _safe_is_subclass( + annotation, BaseModel + ): field_type = annotation filled[key], validation[v_key], final[key] = cls._fill( value, @@ -973,7 +976,9 @@ def _fill( # manually because .construct doesn't parse anything if get_model_config_extra(schema) in ("forbid", "extra"): result_field_names = get_model_fields_set(result) - exclude = [k for k in dict(result).keys() if k not in result_field_names] + exclude = [ + k for k in dict(result).keys() if k not in result_field_names + ] exclude_validation = set([ARGS_FIELD_ALIAS, *RESERVED_FIELDS.keys()]) # Do a shallow serialization first # If any of the sub-objects are Pydantic models, first check if they diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index 34c1864..d30af33 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -256,7 +256,10 @@ class TestSchema(BaseModel): my_registry.resolve({"cfg": config}, schema=TestSchema) -@pytest.mark.skipif(PYDANTIC_V2, reason="In Pydantic v2, int/float cannot be coerced to str so this test will fail.") +@pytest.mark.skipif( + PYDANTIC_V2, + reason="In Pydantic v2, int/float cannot be coerced to str so this test will fail.", +) def test_resolve_schema_coerced(): class TestBaseSchema(BaseModel): test1: str @@ -287,14 +290,20 @@ def test_read_config(): assert cfg["pipeline"]["classifier"]["model"]["embedding"]["width"] == 128 -@pytest.mark.parametrize("optimizer_cfg_str", [OPTIMIZER_DATACLASS_CFG, OPTIMIZER_PYDANTIC_CFG], ids=["dataclasses", "pydantic"]) +@pytest.mark.parametrize( + "optimizer_cfg_str", + [OPTIMIZER_DATACLASS_CFG, OPTIMIZER_PYDANTIC_CFG], + ids=["dataclasses", "pydantic"], +) def test_optimizer_config(optimizer_cfg_str: str): cfg = Config().from_str(optimizer_cfg_str) optimizer = my_registry.resolve(cfg, validate=True)["optimizer"] assert optimizer.beta1 == 0.9 -@pytest.mark.parametrize("optimizer_cfg_str", [OPTIMIZER_DATACLASS_CFG, OPTIMIZER_PYDANTIC_CFG]) +@pytest.mark.parametrize( + "optimizer_cfg_str", [OPTIMIZER_DATACLASS_CFG, OPTIMIZER_PYDANTIC_CFG] +) def test_config_to_str(optimizer_cfg_str: str): cfg = Config().from_str(optimizer_cfg_str) assert cfg.to_str().strip() == optimizer_cfg_str.strip() @@ -315,7 +324,9 @@ def test_config_to_str_creates_intermediate_blocks(): ) -@pytest.mark.parametrize("optimizer_cfg_str", [OPTIMIZER_DATACLASS_CFG, OPTIMIZER_PYDANTIC_CFG]) +@pytest.mark.parametrize( + "optimizer_cfg_str", [OPTIMIZER_DATACLASS_CFG, OPTIMIZER_PYDANTIC_CFG] +) def test_config_roundtrip_bytes(optimizer_cfg_str: str): cfg = Config().from_str(optimizer_cfg_str) cfg_bytes = cfg.to_bytes() @@ -323,7 +334,9 @@ def test_config_roundtrip_bytes(optimizer_cfg_str: str): assert new_cfg.to_str().strip() == optimizer_cfg_str.strip() -@pytest.mark.parametrize("optimizer_cfg_str", [OPTIMIZER_DATACLASS_CFG, OPTIMIZER_PYDANTIC_CFG]) +@pytest.mark.parametrize( + "optimizer_cfg_str", [OPTIMIZER_DATACLASS_CFG, OPTIMIZER_PYDANTIC_CFG] +) def test_config_roundtrip_disk(optimizer_cfg_str: str): cfg = Config().from_str(optimizer_cfg_str) with make_tempdir() as path: @@ -332,8 +345,13 @@ def test_config_roundtrip_disk(optimizer_cfg_str: str): new_cfg = Config().from_disk(cfg_path) assert new_cfg.to_str().strip() == optimizer_cfg_str.strip() -@pytest.mark.parametrize("optimizer_cfg_str", [OPTIMIZER_DATACLASS_CFG, OPTIMIZER_PYDANTIC_CFG]) -def test_config_roundtrip_disk_respects_path_subclasses(pathy_fixture, optimizer_cfg_str: str): + +@pytest.mark.parametrize( + "optimizer_cfg_str", [OPTIMIZER_DATACLASS_CFG, OPTIMIZER_PYDANTIC_CFG] +) +def test_config_roundtrip_disk_respects_path_subclasses( + pathy_fixture, optimizer_cfg_str: str +): cfg = Config().from_str(optimizer_cfg_str) cfg_path = pathy_fixture / "config.cfg" cfg.to_disk(cfg_path) @@ -360,7 +378,6 @@ def test_validation_custom_types(): else: log_field = Field("ERROR", regex="(DEBUG|INFO|WARNING|ERROR)") - def complex_args( rate: StrictFloat, steps: PositiveInt = 10, @@ -1436,6 +1453,8 @@ def test_warn_single_quotes(): def test_parse_strings_interpretable_as_ints(): """Test whether strings interpretable as integers are parsed correctly (i. e. as strings).""" - cfg = Config().from_str(f"""[a]\nfoo = [${{b.bar}}, "00${{b.bar}}", "y"]\n\n[b]\nbar = 3""") + cfg = Config().from_str( + f"""[a]\nfoo = [${{b.bar}}, "00${{b.bar}}", "y"]\n\n[b]\nbar = 3""" + ) assert cfg["a"]["foo"] == [3, "003", "y"] assert cfg["b"]["bar"] == 3 From 36d1d1ca3473accf19de14ba6a200b6e2e0cc2b0 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Thu, 3 Aug 2023 15:44:31 +0200 Subject: [PATCH 35/45] Lower typing_extensions pin for python 3.6 --- requirements.txt | 2 +- setup.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 0cb4ae3..3ae08e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ pydantic>=1.7.4,!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,<3.0.0 -typing_extensions>=4.5.0; python_version < "3.8" +typing_extensions>=4.2.0; python_version < "3.8" srsly>=2.4.0,<3.0.0 # Development requirements pathy>=0.3.5 diff --git a/setup.cfg b/setup.cfg index a4e4895..eb4c665 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,7 +31,7 @@ include_package_data = true python_requires = >=3.6 install_requires = pydantic>=1.7.4,!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,<3.0.0 - typing_extensions>=4.5.0; python_version < "3.8" + typing_extensions>=4.2.0; python_version < "3.8" srsly>=2.4.0,<3.0.0 [sdist] From de144314e30b6aa88f903125ef8d194c9a40cbd8 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Thu, 3 Aug 2023 15:46:16 +0200 Subject: [PATCH 36/45] Undo changes to typing_extensions --- requirements.txt | 2 +- setup.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 3ae08e3..6ddd138 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ pydantic>=1.7.4,!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,<3.0.0 -typing_extensions>=4.2.0; python_version < "3.8" +typing_extensions>=3.7.4.1,<4.5.0; python_version < "3.8" srsly>=2.4.0,<3.0.0 # Development requirements pathy>=0.3.5 diff --git a/setup.cfg b/setup.cfg index eb4c665..fe44a3a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,7 +31,7 @@ include_package_data = true python_requires = >=3.6 install_requires = pydantic>=1.7.4,!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,<3.0.0 - typing_extensions>=4.2.0; python_version < "3.8" + typing_extensions>=3.7.4.1,<4.5.0; python_version < "3.8" srsly>=2.4.0,<3.0.0 [sdist] From 3283e4a15f3a3b851c7021a67063c3a150c87734 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Thu, 3 Aug 2023 15:50:50 +0200 Subject: [PATCH 37/45] Allow older pydantic v1 for tests for python 3.6 --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5880f6c..e7b0f10 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -43,7 +43,7 @@ jobs: python -m mypy confection if: matrix.python_version != '3.6' - - name: Run mypy for Pydantic v1 + - name: Run mypy for Pydantic v1.10 run: | python -m pip install -U pydantic==1.10.* python -m mypy confection @@ -70,7 +70,7 @@ jobs: - name: Install test requirements with Pydantic v1 run: | python -m pip install -U -r requirements.txt - python -m pip install -U pydantic==1.10.* + python -m pip install -U "pydantic<2.0" - name: Run tests for Pydantic v1 run: | From cbada4bed45698c3953e98b902ea01181586e2d1 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Thu, 3 Aug 2023 17:01:00 +0200 Subject: [PATCH 38/45] Add CI test for spacy init config --- .github/workflows/tests.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e7b0f10..04f2c97 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -91,3 +91,8 @@ jobs: run: | python -m pip install hypothesis python -c "import confection; import hypothesis" + + - name: Test with spacy + run: | + python -m pip install spacy + python -m spacy init config -p tagger tagger.cfg From 680e224d340397ab7d70e7e926b01f0df9d7bf89 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Thu, 3 Aug 2023 10:22:32 -0700 Subject: [PATCH 39/45] black formatting --- confection/__init__.py | 1 - confection/tests/test_config.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/confection/__init__.py b/confection/__init__.py index 4e084d7..cba2949 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -963,7 +963,6 @@ def _fill( try: result = model_validate(schema, validation) except ValidationError as e: - raise ConfigValidationError( config=config, errors=e.errors(), parent=parent ) from None diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index d30af33..946f81e 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -1340,6 +1340,7 @@ class BaseSchema(BaseModel): assert filled2["catsie"]["cute"] is True resolved = my_registry.resolve(filled2) assert resolved["catsie"] == "meow" + # With unavailable function class BaseSchema2(BaseModel): catsie: Any From 4f7d5b3b582d93349282306116fead0d712fd60b Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Thu, 3 Aug 2023 23:39:55 -0700 Subject: [PATCH 40/45] Fix spacy init issue (#37) * use old implementation of modelfield copy * ignore type error * Update __init__.py --------- Co-authored-by: Adriane Boyd --- confection/__init__.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/confection/__init__.py b/confection/__init__.py index 3ca1f24..8cc92c5 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -19,6 +19,12 @@ from .util import Decorator, SimpleFrozenDict, SimpleFrozenList, PYDANTIC_V2 +if PYDANTIC_V2: + from pydantic.v1.fields import ModelField # type: ignore +else: + from pydantic.fields import ModelField # type: ignore + + # Field used for positional arguments, e.g. [section.*.xyz]. The alias is # required for the schema (shouldn't clash with user-defined arg names) ARGS_FIELD = "*" @@ -661,16 +667,28 @@ def alias_generator(name: str) -> str: return name +def _copy_model_field_v1(field: ModelField, type_: Any) -> ModelField: + return ModelField( + name=field.name, + type_=type_, + class_validators=field.class_validators, + model_config=field.model_config, + default=field.default, + default_factory=field.default_factory, + required=field.required, + ) + + def copy_model_field(field: FieldInfo, type_: Any) -> FieldInfo: """Copy a model field and assign a new type, e.g. to accept an Any type even though the original value is typed differently. """ - field_info = copy.deepcopy(field) if PYDANTIC_V2: + field_info = copy.deepcopy(field) field_info.annotation = type_ # type: ignore + return field_info else: - field_info.type_ = type_ # type: ignore - return field_info + return _copy_model_field_v1(field, type_) # type: ignore def get_model_config_extra(model: Type[BaseModel]) -> str: From c4f78e80190c9413941a423dcd812572c22fbbb6 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Fri, 4 Aug 2023 08:43:26 +0200 Subject: [PATCH 41/45] Simplify pydantic requirements --- requirements.txt | 2 +- setup.cfg | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 6ddd138..e107fba 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -pydantic>=1.7.4,!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,<3.0.0 +pydantic>=1.7.4,!=1.8,!=1.8.1,<3.0.0 typing_extensions>=3.7.4.1,<4.5.0; python_version < "3.8" srsly>=2.4.0,<3.0.0 # Development requirements diff --git a/setup.cfg b/setup.cfg index fe44a3a..69167f7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,7 +30,7 @@ zip_safe = true include_package_data = true python_requires = >=3.6 install_requires = - pydantic>=1.7.4,!=1.8,!=1.8.1,!=2.0.0,!=2.0.1,<3.0.0 + pydantic>=1.7.4,!=1.8,!=1.8.1,<3.0.0 typing_extensions>=3.7.4.1,<4.5.0; python_version < "3.8" srsly>=2.4.0,<3.0.0 From 8ae9252b77662a6e3072ea0d2406fb6e7d4690c6 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Thu, 31 Aug 2023 12:30:47 -0700 Subject: [PATCH 42/45] add a spacy init config regression test if spacy is installed --- confection/tests/test_config.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/confection/tests/test_config.py b/confection/tests/test_config.py index 5f1344c..b177ca5 100644 --- a/confection/tests/test_config.py +++ b/confection/tests/test_config.py @@ -1457,3 +1457,14 @@ def test_parse_strings_interpretable_as_ints(): ) assert cfg["a"]["foo"] == [3, "003", "y"] assert cfg["b"]["bar"] == 3 + + +def test_spacy_init_config(): + """Regression test to ensure spacy init config works""" + try: + from spacy.cli import init_config + except ImportError: + pytest.skip("SpaCy not installed") + + config = init_config(pipeline=["tagger"]) + assert isinstance(config, Config) From 883d76b049ee0ebb4cae167465e4ea16dbde8aef Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Thu, 31 Aug 2023 12:37:02 -0700 Subject: [PATCH 43/45] rm unused import --- confection/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/confection/__init__.py b/confection/__init__.py index 187c7a9..9f45658 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -18,7 +18,6 @@ from pathlib import Path from types import GeneratorType from typing import ( - TYPE_CHECKING, Any, Callable, Dict, From d25d81adc16592e4bd40b7c96af7d1a11dfb3cb4 Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Thu, 31 Aug 2023 12:37:22 -0700 Subject: [PATCH 44/45] rm spacy init config step in gha in favor of test in pytest --- .github/workflows/tests.yml | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8ed7de1..e52a13f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -97,7 +97,7 @@ jobs: - name: Install test requirements with Pydantic v1 run: | python -m pip install -U -r requirements.txt - python -m pip install -U "pydantic<2.0" + python -m pip install -U "pydantic<2.0" "spacy" - name: Run tests for Pydantic v1 run: | @@ -106,7 +106,7 @@ jobs: - name: Install test requirements with Pydantic v2 run: | - python -m pip install -U -r requirements.txt + python -m pip install -U -r requirements.txt spacy python -m pip install -U pydantic - name: Run tests for Pydantic v2 @@ -118,8 +118,3 @@ jobs: run: | python -m pip install hypothesis python -c "import confection; import hypothesis" - - - name: Test with spacy - run: | - python -m pip install spacy - python -m spacy init config -p tagger tagger.cfg From ba068d60dd596d30075516d533f7c96727bc197e Mon Sep 17 00:00:00 2001 From: Kabir Khan Date: Thu, 31 Aug 2023 20:06:21 -0700 Subject: [PATCH 45/45] fix checks --- confection/__init__.py | 46 +++++++++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/confection/__init__.py b/confection/__init__.py index 9f45658..edcde34 100644 --- a/confection/__init__.py +++ b/confection/__init__.py @@ -728,18 +728,37 @@ def get_model_config_extra(model: Type[BaseModel]) -> str: _ModelT = TypeVar("_ModelT", bound=BaseModel) +def _schema_is_pydantic_v2(Schema: Union[Type[BaseModel], BaseModel]) -> bool: + """If `model_fields` attr is present, it means we have a schema or instance + of a pydantic v2 BaseModel. Even if we're using Pydantic V2, users could still + import from `pydantic.v1` and that would break our compat checks. + Schema (Union[Type[BaseModel], BaseModel]): Input schema or instance. + RETURNS (bool): True if the pydantic model is a v2 model or not + """ + return hasattr(Schema, "model_fields") + + def model_validate(Schema: Type[_ModelT], data: Dict[str, Any]) -> _ModelT: - return Schema.model_validate(data) if PYDANTIC_V2 else Schema(**data) # type: ignore + if PYDANTIC_V2 and _schema_is_pydantic_v2(Schema): + return Schema.model_validate(data) # type: ignore + else: + return Schema.validate(data) # type: ignore def model_construct( Schema: Type[_ModelT], fields_set: Optional[Set[str]], data: Dict[str, Any] ) -> _ModelT: - return Schema.model_construct(fields_set, **data) if PYDANTIC_V2 else Schema.construct(fields_set, **data) # type: ignore + if PYDANTIC_V2 and _schema_is_pydantic_v2(Schema): + return Schema.model_construct(fields_set, **data) # type: ignore + else: + return Schema.construct(fields_set, **data) # type: ignore def model_dump(instance: BaseModel) -> Dict[str, Any]: - return instance.model_dump() if PYDANTIC_V2 else instance.dict() # type: ignore + if PYDANTIC_V2 and _schema_is_pydantic_v2(instance): + return instance.model_dump() # type: ignore + else: + return instance.dict() def get_field_annotation(field: FieldInfo) -> Type: @@ -747,19 +766,28 @@ def get_field_annotation(field: FieldInfo) -> Type: def get_model_fields(Schema: Union[Type[BaseModel], BaseModel]) -> Dict[str, FieldInfo]: - return Schema.model_fields if PYDANTIC_V2 else Schema.__fields__ # type: ignore + if PYDANTIC_V2 and _schema_is_pydantic_v2(Schema): + return Schema.model_fields # type: ignore + else: + return Schema.__fields__ # type: ignore def get_model_fields_set(Schema: Union[Type[BaseModel], BaseModel]) -> Set[str]: - return Schema.model_fields_set if PYDANTIC_V2 else Schema.__fields_set__ # type: ignore + if PYDANTIC_V2 and _schema_is_pydantic_v2(Schema): + return Schema.model_fields_set # type: ignore + else: + return Schema.__fields_set__ # type: ignore -def get_model_extra(model: BaseModel) -> Dict[str, FieldInfo]: - return model.model_extra or {} if PYDANTIC_V2 else {} # type: ignore +def get_model_extra(instance: BaseModel) -> Dict[str, FieldInfo]: + if PYDANTIC_V2 and _schema_is_pydantic_v2(instance): + return instance.model_extra # type: ignore + else: + return {} def set_model_field(Schema: Type[BaseModel], key: str, field: FieldInfo): - if PYDANTIC_V2: + if PYDANTIC_V2 and _schema_is_pydantic_v2(Schema): Schema.model_fields[key] = field # type: ignore else: Schema.__fields__[key] = field # type: ignore @@ -768,7 +796,7 @@ def set_model_field(Schema: Type[BaseModel], key: str, field: FieldInfo): def update_from_model_extra( shallow_result_dict: Dict[str, Any], result: BaseModel ) -> None: - if PYDANTIC_V2: + if PYDANTIC_V2 and _schema_is_pydantic_v2(result): if result.model_extra is not None: # type: ignore shallow_result_dict.update(result.model_extra) # type: ignore