From bb4195563e0b468109c9c4f59f9a0b7400c767a8 Mon Sep 17 00:00:00 2001 From: Mart Ratas Date: Wed, 27 Nov 2024 11:51:32 +0000 Subject: [PATCH] CU-8695d4www pydantic 2 (#476) * CU-8695d4www: Bump pydantic requirement to 2.6+ * CU-8695d4www: Update methods to use pydantic2 based ones * CU-8695d4www: Update methods to use pydantic2 based ones [part 2] * CU-8695d4www: Use identifier based config when setting last train date on meta cat and tner * CU-8695d4www: Use pydantic2-based model validation * CU-8695d4www: Add workarounds for pydantic1 methods * CU-8695d4www: Add missing utils module for pydantic1 methods * Revert "CU-8695d4www: Bump pydantic requirement to 2.6+" This reverts commit b0b3d431cc01e2e73c8708bd007e7e948263deb9. * CU-8695d4www: [TEMP] Add type-ingores to pydantic2-based methods for GHA workflow * CU-8695d4www: Make pydantic2-requires getattribute wrapper only apply when appropriate * CU-8695d4www: Fix missin model dump getter abstraction * CU-8695d4www: Fix missin model dump getter abstraction (in CAT) * CU-8695d4www: Update tests for pydantic 1 and 2 support * Revert "CU-8695d4www: [TEMP] Add type-ingores to pydantic2-based methods for GHA workflow" This reverts commit b86135add8c8ad944a83c0c51f425a3e55d940e7. * Reapply "CU-8695d4www: Bump pydantic requirement to 2.6+" This reverts commit 080ae7172434a849e81ae0662d9310588f4bb9a3. * CU-8695d4www: Allow both pydantic 1 and 2 * CU-8695d4www: Deprecated pydantic utils for removal in 1.15 * CU-8695d4www: Allow usage of specified deprecated method(s) during tests * CU-8695d4www: Allow usage of pydantic 1-2 workaround methods during tests * CU-8695d4www: Add documentation for argument allowing usage during tests in deprecation method * CU-8695d4www: Fix allowing deprecation during test time * CU-8695d4www: Fix model dump getting in regression checker * Revert "CU-8695d4www: Fix allowing deprecation during test time" This reverts commit fadc7d18e695e8217e26e5191ff11e36dac44fde. * Revert "CU-8695d4www: Add documentation for argument allowing usage during tests in deprecation method" This reverts commit 927f8078083f1cc033776aa1407a424ac5601391. * Revert "CU-8695d4www: Allow usage of pydantic 1-2 workaround methods during tests" This reverts commit 825628e10db04b45143e8ec84af71781a04d7725. * Revert "CU-8695d4www: Allow usage of specified deprecated method(s) during tests" This reverts commit a89e6804e2d533cc21e4409d0045bb7c63bf743f. * Revert "CU-8695d4www: Deprecated pydantic utils for removal in 1.15" This reverts commit 0ee1a8abc3fa429beb3094c4ff465876d41677e6. * CU-8695d4www: Add comment regarding pydantic backwards compatiblity where applicable * CU-8695d4www: Add pydantic 1 check to GHA workflow * CU-8695d4www: Fix usage of pydantic-1 based dict method in regression results * CU-8695d4www: Fix usage of pydantic-1 based dict method in regression tests * CU-8695d4www: New workflow step to install and run mypy on pydantic 1 * CU-8695d4www: Add type ignore comments to pydantic2 versions in versioning utils for typing during GHA workflow * CU-8695d4www: Update pydantic requirement to 2.0+ only * CU-8695d4www: Update to pydantic 2 ONLY * CU-869671bn4: Update mypy dev requirement to be less than 1.12 * CU-869671bn4: Fix model fields in config * CU-869671bn4: Fix stats helper method - use correct type adapter * CU-869671bn4: Fix some model type issues * CU-869671bn4: Line up with previous model dump methods * CU-869671bn4: Fix overwriting model dump methods * CU-869671bn4: Remove pydantic1 workflow step --- .github/workflows/main.yml | 8 +++- install_requires.txt | 2 +- medcat/cat.py | 2 +- medcat/config.py | 41 ++++++++++--------- medcat/config_meta_cat.py | 12 +++--- medcat/config_rel_cat.py | 8 ++-- medcat/config_transformers_ner.py | 8 ++-- medcat/meta_cat.py | 6 +-- medcat/ner/transformers_ner.py | 6 +-- medcat/utils/regression/checking.py | 2 +- medcat/utils/regression/regression_checker.py | 4 +- medcat/utils/regression/results.py | 22 +++++----- tests/stats/helpers.py | 13 +++++- tests/stats/test_kfold.py | 9 +--- tests/stats/test_mctexport.py | 5 +-- tests/utils/regression/test_checking.py | 8 ++-- tests/utils/regression/test_results.py | 6 +-- 17 files changed, 86 insertions(+), 76 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b4a84f16d..1b7232bb6 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -31,6 +31,13 @@ jobs: - name: Lint run: | flake8 medcat + - name: Pydantic 1 check + # NOTE: the following will look for use of pydantic1-specific .dict() method and .__fields__ attribute + # if there are some (that are not annotated for pydantic1 backwards compatibility) a non-zero exit + # code is returned, which will hald the workflow and print out the offending parts + run: | + grep "\.__fields__" medcat -rI | grep -v "# 4pydantic1 - backwards compatibility" | tee /dev/stderr | test $(wc -l) -eq 0 + grep "\.dict(" medcat -rI | grep -v "# 4pydantic1 - backwards compatibility" | tee /dev/stderr | test $(wc -l) -eq 0 - name: Test run: | all_files=$(git ls-files | grep '^tests/.*\.py$' | grep -v '/__init__\.py$' | sed 's/\.py$//' | sed 's/\//./g') @@ -54,7 +61,6 @@ jobs: repo: context.repo.repo }); core.setOutput('latest_version', latestRelease.data.tag_name); - - name: Make sure there's no deprecated methods that should be removed. # only run this for master -> production PR. I.e just before doing a release. if: github.event.pull_request.base.ref == 'main' && github.event.pull_request.head.ref == 'production' diff --git a/install_requires.txt b/install_requires.txt index 77b610825..136728d89 100644 --- a/install_requires.txt +++ b/install_requires.txt @@ -19,6 +19,6 @@ 'xxhash>=3.0.0' # allow later versions, tested with 3.1.0 'blis>=0.7.5,<1.0.0' # allow later versions, tested with 0.7.9, avoid 1.0.0 (depends on numpy 2) 'click>=8.0.4' # allow later versions, tested with 8.1.3 -'pydantic>=1.10.0,<2.0' # for spacy compatibility; avoid 2.0 due to breaking changes +'pydantic>=2.0.0,<3.0' # avoid next major release "humanfriendly~=10.0" # for human readable file / RAM sizes "peft>=0.8.2" diff --git a/medcat/cat.py b/medcat/cat.py index 707dbd7f3..13042acd0 100644 --- a/medcat/cat.py +++ b/medcat/cat.py @@ -590,7 +590,7 @@ def _print_stats(self, def _init_ckpts(self, is_resumed, checkpoint): if self.config.general.checkpoint.steps is not None or checkpoint is not None: - checkpoint_config = CheckpointConfig(**self.config.general.checkpoint.dict()) + checkpoint_config = CheckpointConfig(**self.config.general.checkpoint.model_dump()) checkpoint_manager = CheckpointManager('cat_train', checkpoint_config) if is_resumed: # TODO: probably remove is_resumed mark and always resume if a checkpoint is provided, diff --git a/medcat/config.py b/medcat/config.py index a1dd15e78..fb3ab5f03 100644 --- a/medcat/config.py +++ b/medcat/config.py @@ -1,6 +1,5 @@ from datetime import datetime -from pydantic import BaseModel, Extra, ValidationError -from pydantic.fields import ModelField +from pydantic import BaseModel, ValidationError from typing import List, Set, Tuple, cast, Any, Callable, Dict, Optional, Union, Type, Literal from multiprocessing import cpu_count import logging @@ -125,7 +124,7 @@ def merge_config(self, config_dict: Dict) -> None: attr = None # new attribute value = config_dict[key] if isinstance(value, BaseModel): - value = value.dict() + value = value.model_dump() if isinstance(attr, MixingConfig): attr.merge_config(value) else: @@ -177,7 +176,7 @@ def rebuild_re(self) -> None: def _calc_hash(self, hasher: Optional[Hasher] = None) -> Hasher: if hasher is None: hasher = Hasher() - for _, v in cast(BaseModel, self).dict().items(): + for _, v in cast(BaseModel, self).model_dump().items(): if isinstance(v, MixingConfig): v._calc_hash(hasher) else: @@ -189,7 +188,7 @@ def get_hash(self, hasher: Optional[Hasher] = None): return hasher.hexdigest() def __str__(self) -> str: - return str(cast(BaseModel, self).dict()) + return str(cast(BaseModel, self).model_dump()) @classmethod def load(cls, save_path: str) -> "MixingConfig": @@ -238,15 +237,15 @@ def asdict(self) -> Dict[str, Any]: Returns: Dict[str, Any]: The dictionary associated with this config """ - return cast(BaseModel, self).dict() + return cast(BaseModel, self).model_dump() - def fields(self) -> Dict[str, ModelField]: + def fields(self) -> dict: """Get the fields associated with this config. Returns: - Dict[str, ModelField]: The dictionary of the field names and fields + dict: The dictionary of the field names and fields """ - return cast(BaseModel, self).__fields__ + return cast(BaseModel, self).model_fields class VersionInfo(MixingConfig, BaseModel): @@ -272,7 +271,7 @@ class VersionInfo(MixingConfig, BaseModel): """Which version of medcat was used to build the CDB""" class Config: - extra = Extra.allow + extra = 'allow' validate_assignment = True @@ -290,7 +289,7 @@ class CDBMaker(MixingConfig, BaseModel): """Minimum number of letters required in a name to be accepted for a concept""" class Config: - extra = Extra.allow + extra = 'allow' validate_assignment = True @@ -303,7 +302,7 @@ class AnnotationOutput(MixingConfig, BaseModel): include_text_in_output: bool = False class Config: - extra = Extra.allow + extra = 'allow' validate_assignment = True @@ -317,7 +316,7 @@ class CheckPoint(MixingConfig, BaseModel): """When training the maximum checkpoints will be kept on the disk""" class Config: - extra = Extra.allow + extra = 'allow' validate_assignment = True @@ -354,7 +353,7 @@ class General(MixingConfig, BaseModel): NB! For these changes to take effect, the pipe would need to be recreated.""" checkpoint: CheckPoint = CheckPoint() - usage_monitor = UsageMonitor() + usage_monitor: UsageMonitor = UsageMonitor() """Checkpointing config""" log_level: int = logging.INFO """Logging config for everything | 'tagger' can be disabled, but will cause a drop in performance""" @@ -395,7 +394,7 @@ class General(MixingConfig, BaseModel): reliable due to not taking into account all the details of the changes.""" class Config: - extra = Extra.allow + extra = 'allow' validate_assignment = True @@ -424,7 +423,7 @@ class Preprocessing(MixingConfig, BaseModel): NB! For these changes to take effect, the pipe would need to be recreated.""" class Config: - extra = Extra.allow + extra = 'allow' validate_assignment = True @@ -444,7 +443,7 @@ class Ner(MixingConfig, BaseModel): """Try reverse word order for short concepts (2 words max), e.g. heart disease -> disease heart""" class Config: - extra = Extra.allow + extra = 'allow' validate_assignment = True @@ -579,7 +578,7 @@ class Linking(MixingConfig, BaseModel): """If true when the context of a concept is calculated (embedding) the words making that concept are not taken into account""" class Config: - extra = Extra.allow + extra = 'allow' validate_assignment = True @@ -600,7 +599,7 @@ class Config: # this if for word_skipper and punct_checker which would otherwise # not have a validator arbitrary_types_allowed = True - extra = Extra.allow + extra = 'allow' validate_assignment = True def __init__(self, *args, **kwargs): @@ -618,7 +617,7 @@ def rebuild_re(self) -> None: # Override def get_hash(self): hasher = Hasher() - for k, v in self.dict().items(): + for k, v in self.model_dump().items(): if k in ['hash', ]: # ignore hash continue @@ -674,4 +673,6 @@ def wrapper(*args, **kwargs): # we get a nicer exceptio _waf_advice = "You can use `cat.cdb.weighted_average_function` to access it directly" Linking.__getattribute__ = _wrapper(Linking.__getattribute__, Linking, _waf_advice, AttributeError) # type: ignore +if hasattr(Linking, '__getattr__'): + Linking.__getattr__ = _wrapper(Linking.__getattr__, Linking, _waf_advice, AttributeError) # type: ignore Linking.__getitem__ = _wrapper(Linking.__getitem__, Linking, _waf_advice, KeyError) # type: ignore diff --git a/medcat/config_meta_cat.py b/medcat/config_meta_cat.py index ef8f908f2..0d6eb7a64 100644 --- a/medcat/config_meta_cat.py +++ b/medcat/config_meta_cat.py @@ -1,5 +1,5 @@ from typing import Dict, Any -from medcat.config import MixingConfig, BaseModel, Optional, Extra +from medcat.config import MixingConfig, BaseModel, Optional class General(MixingConfig, BaseModel): @@ -65,7 +65,7 @@ class General(MixingConfig, BaseModel): Otherwise defaults to doc._.ents or doc.ents per the annotate_overlapping settings""" class Config: - extra = Extra.allow + extra = 'allow' validate_assignment = True @@ -169,7 +169,7 @@ class Model(MixingConfig, BaseModel): """If set to True center positions will be ignored when calculating representation""" class Config: - extra = Extra.allow + extra = 'allow' validate_assignment = True @@ -191,7 +191,7 @@ class Train(MixingConfig, BaseModel): """If set only this CUIs will be used for training""" auto_save_model: bool = True """Should do model be saved during training for best results""" - last_train_on: Optional[int] = None + last_train_on: Optional[float] = None """When was the last training run""" metric: Dict[str, str] = {'base': 'weighted avg', 'score': 'f1-score'} """What metric should be used for choosing the best model""" @@ -206,7 +206,7 @@ class Train(MixingConfig, BaseModel): """Focal Loss hyperparameter - determines importance the loss gives to hard-to-classify examples""" class Config: - extra = Extra.allow + extra = 'allow' validate_assignment = True @@ -217,5 +217,5 @@ class ConfigMetaCAT(MixingConfig, BaseModel): train: Train = Train() class Config: - extra = Extra.allow + extra = 'allow' validate_assignment = True diff --git a/medcat/config_rel_cat.py b/medcat/config_rel_cat.py index dfa3b0099..c16735d66 100644 --- a/medcat/config_rel_cat.py +++ b/medcat/config_rel_cat.py @@ -1,6 +1,6 @@ import logging from typing import Dict, Any, List -from medcat.config import MixingConfig, BaseModel, Optional, Extra +from medcat.config import MixingConfig, BaseModel, Optional class General(MixingConfig, BaseModel): @@ -89,7 +89,7 @@ class Model(MixingConfig, BaseModel): """If set to True center positions will be ignored when calculating representation""" class Config: - extra = Extra.allow + extra = 'allow' validate_assignment = True @@ -116,7 +116,7 @@ class Train(MixingConfig, BaseModel): """Should the model be saved during training for best results""" class Config: - extra = Extra.allow + extra = 'allow' validate_assignment = True @@ -127,5 +127,5 @@ class ConfigRelCAT(MixingConfig, BaseModel): train: Train = Train() class Config: - extra = Extra.allow + extra = 'allow' validate_assignment = True diff --git a/medcat/config_transformers_ner.py b/medcat/config_transformers_ner.py index 9f3102acb..e9661aaf2 100644 --- a/medcat/config_transformers_ner.py +++ b/medcat/config_transformers_ner.py @@ -1,4 +1,4 @@ -from medcat.config import MixingConfig, BaseModel, Optional, Extra +from medcat.config import MixingConfig, BaseModel, Optional class General(MixingConfig, BaseModel): @@ -16,11 +16,11 @@ class General(MixingConfig, BaseModel): chunking_overlap_window: Optional[int] = 5 """Size of the overlap window used for chunking""" test_size: float = 0.2 - last_train_on: Optional[int] = None + last_train_on: Optional[float] = None verbose_metrics: bool = False class Config: - extra = Extra.allow + extra = 'allow' validate_assignment = True @@ -29,5 +29,5 @@ class ConfigTransformersNER(MixingConfig, BaseModel): general: General = General() class Config: - extra = Extra.allow + extra = 'allow' validate_assignment = True diff --git a/medcat/meta_cat.py b/medcat/meta_cat.py index 386bbe0cf..9182fe00e 100644 --- a/medcat/meta_cat.py +++ b/medcat/meta_cat.py @@ -114,8 +114,8 @@ def get_hash(self) -> str: """ hasher = Hasher() # Set last_train_on if None - if self.config.train['last_train_on'] is None: - self.config.train['last_train_on'] = datetime.now().timestamp() + if self.config.train.last_train_on is None: + self.config.train.last_train_on = datetime.now().timestamp() hasher.update(self.config.get_hash()) return hasher.hexdigest() @@ -310,7 +310,7 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data # Save everything now self.save(save_dir_path=save_dir_path) - self.config.train['last_train_on'] = datetime.now().timestamp() + self.config.train.last_train_on = datetime.now().timestamp() return report def eval(self, json_path: str) -> Dict: diff --git a/medcat/ner/transformers_ner.py b/medcat/ner/transformers_ner.py index 1de8d6d83..392b4a94d 100644 --- a/medcat/ner/transformers_ner.py +++ b/medcat/ner/transformers_ner.py @@ -103,8 +103,8 @@ def get_hash(self) -> str: """ hasher = Hasher() # Set last_train_on if None - if self.config.general['last_train_on'] is None: - self.config.general['last_train_on'] = datetime.now().timestamp() + if self.config.general.last_train_on is None: + self.config.general.last_train_on = datetime.now().timestamp() hasher.update(self.config.get_hash()) return hasher.hexdigest() @@ -242,7 +242,7 @@ def train(self, trainer.train() # type: ignore # Save the training time - self.config.general['last_train_on'] = datetime.now().timestamp() # type: ignore + self.config.general.last_train_on = datetime.now().timestamp() # type: ignore # Save everything self.save(save_dir_path=os.path.join(self.training_arguments.output_dir, 'final_model')) diff --git a/medcat/utils/regression/checking.py b/medcat/utils/regression/checking.py index 2c2d52ce9..f470fc2ce 100644 --- a/medcat/utils/regression/checking.py +++ b/medcat/utils/regression/checking.py @@ -411,7 +411,7 @@ def to_dict(self) -> dict: d = {} for case in self.cases: d[case.name] = case.to_dict() - d['meta'] = self.metadata.dict() + d['meta'] = self.metadata.model_dump() fix_np_float64(d['meta']) return d diff --git a/medcat/utils/regression/regression_checker.py b/medcat/utils/regression/regression_checker.py index 4906e9db8..d494a40dc 100644 --- a/medcat/utils/regression/regression_checker.py +++ b/medcat/utils/regression/regression_checker.py @@ -118,8 +118,8 @@ def main(model_pack_dir: Path, test_suite_file: Path, examples_strictness = Strictness[examples_strictness_str] if jsonpath: logger.info('Writing to %s', str(jsonpath)) - jsonpath.write_text(json.dumps(res.dict(strictness=examples_strictness), - indent=jsonindent)) + dumped = res.model_dump(strictness=examples_strictness) + jsonpath.write_text(json.dumps(dumped, indent=jsonindent)) else: logger.info(res.get_report(phrases_separately=phrases, hide_empty=hide_empty, examples_strictness=examples_strictness, diff --git a/medcat/utils/regression/results.py b/medcat/utils/regression/results.py index 2667a970a..1a6bf8932 100644 --- a/medcat/utils/regression/results.py +++ b/medcat/utils/regression/results.py @@ -1,5 +1,5 @@ from enum import Enum, auto -from typing import Dict, List, Optional, Any, Set, Iterable, Tuple +from typing import Dict, List, Optional, Any, Set, Iterable, Tuple, cast import json import pydantic @@ -372,7 +372,7 @@ def get_report(self) -> str: ]) return "\n".join(ret_vals) - def dict(self, **kwargs) -> dict: + def model_dump(self, **kwargs) -> dict: if 'strictness' in kwargs: kwargs = kwargs.copy() # so if used elsewhere, keeps the kwarg strict_raw = kwargs.pop('strictness') @@ -395,17 +395,17 @@ def dict(self, **kwargs) -> dict: key.name: value for key, value in self.findings.items() } serialized_examples = [ - (ft.dict(**kwargs), (f[0].name, f[1])) for ft, f in self.examples + (ft.model_dump(**kwargs), (f[0].name, f[1])) for ft, f in self.examples # only count if NOT in strictness matrix (i.e 'failures') if f[0] not in STRICTNESS_MATRIX[strictness] ] - model_dict = super().dict(**kwargs) + model_dict = cast(pydantic.BaseModel, super()).model_dump(**kwargs) model_dict['findings'] = serialized_dict model_dict['examples'] = serialized_examples return model_dict def json(self, **kwargs) -> str: - d = self.dict(**kwargs) + d = self.model_dump(**kwargs) return json.dumps(d) @@ -478,7 +478,7 @@ def get_report(self, phrases_separately: bool = False) -> str: for srd in self.per_phrase_results.values()]) return sr + '\n\t\t' + children.replace('\n', '\n\t\t') - def dict(self, **kwargs) -> dict: + def model_dump(self, **kwargs) -> dict: if 'exclude' in kwargs and kwargs['exclude'] is not None: exclude: set = kwargs['exclude'] else: @@ -486,7 +486,7 @@ def dict(self, **kwargs) -> dict: kwargs['exclude'] = exclude # NOTE: ignoring here so that examples are only present in the per phrase part exclude.update(('examples', 'per_phrase_results')) - d = super().dict(**kwargs) + d = cast(pydantic.BaseModel, super()).model_dump(**kwargs) if 'examples' in d: # NOTE: I don't really know why, but the examples still # seem to be a part of the resulting dict, so I need @@ -495,7 +495,7 @@ def dict(self, **kwargs) -> dict: # NOTE: need to propagate here manually so the strictness keyword # makes sense and doesn't cause issues due being to unexpected keyword per_phrase_results = { - phrase: res.dict(**kwargs) for phrase, res in + phrase: res.model_dump(**kwargs) for phrase, res in sorted(self.per_phrase_results.items(), key=lambda it: it[0]) } d['per_phrase_results'] = per_phrase_results @@ -677,7 +677,7 @@ def get_report(self, phrases_separately: bool, ]) return "\n".join(ret_vals) + f"\n{delegated}" - def dict(self, **kwargs) -> dict: + def model_dump(self, **kwargs) -> dict: if 'strictness' in kwargs: strict_raw = kwargs.pop('strictness') if isinstance(strict_raw, Strictness): @@ -688,8 +688,8 @@ def dict(self, **kwargs) -> dict: raise ValueError(f"Unknown stircntess specified: {strict_raw}") else: strictness = Strictness.NORMAL - out_dict = super().dict(exclude={'parts'}, **kwargs) - out_dict['parts'] = [part.dict(strictness=strictness) for part in self.parts] + out_dict = cast(pydantic.BaseModel, super()).model_dump(exclude={'parts'}, **kwargs) + out_dict['parts'] = [part.model_dump(strictness=strictness) for part in self.parts] return out_dict diff --git a/tests/stats/helpers.py b/tests/stats/helpers.py index 80771b11c..af58517bc 100644 --- a/tests/stats/helpers.py +++ b/tests/stats/helpers.py @@ -1,9 +1,13 @@ -from pydantic import create_model_from_typeddict +import pydantic +from unittest import TestCase + +import pydantic.error_wrappers from medcat.stats.mctexport import MedCATTrainerExport -MCTExportPydanticModel = create_model_from_typeddict(MedCATTrainerExport) +MCTExportPydanticModel = pydantic.TypeAdapter(MedCATTrainerExport) + def nullify_doc_names_proj_ids(export: MedCATTrainerExport) -> MedCATTrainerExport: @@ -15,3 +19,8 @@ def nullify_doc_names_proj_ids(export: MedCATTrainerExport) -> MedCATTrainerExpo ], key=lambda doc: doc['id']) } for project in export['projects'] ]} + + +def assert_is_mct_export(tc: TestCase, mct_export: dict): + model_instance = MCTExportPydanticModel.validate_python(mct_export) + tc.assertIsInstance(model_instance, dict) # NOTE: otherwise would have raised an exception diff --git a/tests/stats/test_kfold.py b/tests/stats/test_kfold.py index cae44248c..d06b666ec 100644 --- a/tests/stats/test_kfold.py +++ b/tests/stats/test_kfold.py @@ -5,11 +5,10 @@ from medcat.stats import kfold from medcat.cat import CAT -from pydantic.error_wrappers import ValidationError as PydanticValidationError import unittest -from .helpers import MCTExportPydanticModel, nullify_doc_names_proj_ids +from .helpers import assert_is_mct_export, nullify_doc_names_proj_ids class MCTExportTests(unittest.TestCase): @@ -22,11 +21,7 @@ def setUpClass(cls) -> None: cls.mct_export = json.load(f) def assertIsMCTExport(self, obj): - try: - model = MCTExportPydanticModel(**obj) - except PydanticValidationError as e: - raise AssertionError("Not n MCT export") from e - self.assertIsInstance(model, MCTExportPydanticModel) + assert_is_mct_export(self, obj) class KFoldCreatorTests(MCTExportTests): diff --git a/tests/stats/test_mctexport.py b/tests/stats/test_mctexport.py index 8ef11f556..924bbebf7 100644 --- a/tests/stats/test_mctexport.py +++ b/tests/stats/test_mctexport.py @@ -5,7 +5,7 @@ import unittest -from .helpers import MCTExportPydanticModel +from .helpers import assert_is_mct_export class MCTExportIterationTests(unittest.TestCase): @@ -22,8 +22,7 @@ def setUpClass(cls) -> None: def test_conforms_to_template(self): # NOTE: This uses pydantic to make sure that the MedCATTrainerExport # type matches the actual export format - model_instance = MCTExportPydanticModel(**self.mct_export) - self.assertIsInstance(model_instance, MCTExportPydanticModel) + assert_is_mct_export(self, self.mct_export) def test_iterates_over_all_docs(self): self.assertEqual(mctexport.count_all_docs(self.mct_export), self.EXPECTED_DOCS) diff --git a/tests/utils/regression/test_checking.py b/tests/utils/regression/test_checking.py index 06ca933de..9cfffeeb1 100644 --- a/tests/utils/regression/test_checking.py +++ b/tests/utils/regression/test_checking.py @@ -239,7 +239,7 @@ def setUpClass(cls) -> None: final_phrase='FINAL PHRASE'), finding=(Finding.FOUND_OTHER, 'CUI=OTHER')) def test_result_is_json_serialisable(self): - rd = self.res.dict() + rd = self.res.model_dump() s = json.dumps(rd) self.assertIsInstance(s, str) @@ -249,19 +249,19 @@ def test_result_is_json_serialisable_pydantic(self): def test_can_use_strictness(self): e1 = [ - example for part in self.res.dict(strictness=Strictness.STRICTEST)['parts'] + example for part in self.res.model_dump(strictness=Strictness.STRICTEST)['parts'] for per_phrase in part['per_phrase_results'].values() for example in per_phrase['examples'] ] e2 = [ - example for part in self.res.dict(strictness=Strictness.LENIENT)['parts'] + example for part in self.res.model_dump(strictness=Strictness.LENIENT)['parts'] for per_phrase in part['per_phrase_results'].values() for example in per_phrase['examples'] ] self.assertGreater(len(e1), len(e2)) def test_dict_includes_all_parts(self): - d_parts = self.res.dict()['parts'] + d_parts = self.res.model_dump()['parts'] self.assertEqual(len(self.res.parts), len(d_parts)) diff --git a/tests/utils/regression/test_results.py b/tests/utils/regression/test_results.py index d8e27054a..7033e2860 100644 --- a/tests/utils/regression/test_results.py +++ b/tests/utils/regression/test_results.py @@ -291,13 +291,13 @@ def test_can_json_dump_pydantic(self): self.assertIsInstance(s, str) def test_can_json_dump_json(self): - s = json.dumps(self.rd.dict()) + s = json.dumps(self.rd.model_dump()) self.assertIsInstance(s, str) def test_can_use_strictness_for_dump(self): - d_strictest = self.rd.dict(strictness='STRICTEST') + d_strictest = self.rd.model_dump(strictness='STRICTEST') e_strictest = d_strictest['examples'] # this should have more examples - d_lenient = self.rd.dict(strictness='NORMAL') + d_lenient = self.rd.model_dump(strictness='NORMAL') e_normal = d_lenient['examples'] self.assertGreater(len(e_strictest), len(e_normal))