CU-8695d4www pydantic 2 (#476)
* CU-8695d4www: Bump pydantic requirement to 2.6+

* CU-8695d4www: Update methods to use pydantic2 based ones

* CU-8695d4www: Update methods to use pydantic2 based ones [part 2]

* CU-8695d4www: Use identifier based config when setting last train date on meta cat and tner

* CU-8695d4www: Use pydantic2-based model validation

* CU-8695d4www: Add workarounds for pydantic1 methods

* CU-8695d4www: Add missing utils module for pydantic1 methods

* Revert "CU-8695d4www: Bump pydantic requirement to 2.6+"

This reverts commit b0b3d43.

* CU-8695d4www: [TEMP] Add type-ignores to pydantic2-based methods for GHA workflow

* CU-8695d4www: Make pydantic2-required getattribute wrapper only apply when appropriate

* CU-8695d4www: Fix missing model dump getter abstraction

* CU-8695d4www: Fix missing model dump getter abstraction (in CAT)

* CU-8695d4www: Update tests for pydantic 1 and 2 support

* Revert "CU-8695d4www: [TEMP] Add type-ignores to pydantic2-based methods for GHA workflow"

This reverts commit b86135a.

* Reapply "CU-8695d4www: Bump pydantic requirement to 2.6+"

This reverts commit 080ae71.

* CU-8695d4www: Allow both pydantic 1 and 2

* CU-8695d4www: Deprecated pydantic utils for removal in 1.15

* CU-8695d4www: Allow usage of specified deprecated method(s) during tests

* CU-8695d4www: Allow usage of pydantic 1-2 workaround methods during tests

* CU-8695d4www: Add documentation for argument allowing usage during tests in deprecation method

* CU-8695d4www: Fix allowing deprecation during test time

* CU-8695d4www: Fix model dump getting in regression checker

* Revert "CU-8695d4www: Fix allowing deprecation during test time"

This reverts commit fadc7d1.

* Revert "CU-8695d4www: Add documentation for argument allowing usage during tests in deprecation method"

This reverts commit 927f807.

* Revert "CU-8695d4www: Allow usage of pydantic 1-2 workaround methods during tests"

This reverts commit 825628e.

* Revert "CU-8695d4www: Allow usage of specified deprecated method(s) during tests"

This reverts commit a89e680.

* Revert "CU-8695d4www: Deprecated pydantic utils for removal in 1.15"

This reverts commit 0ee1a8a.

* CU-8695d4www: Add comment regarding pydantic backwards compatibility where applicable

* CU-8695d4www: Add pydantic 1 check to GHA workflow

* CU-8695d4www: Fix usage of pydantic-1 based dict method in regression results

* CU-8695d4www: Fix usage of pydantic-1 based dict method in regression tests

* CU-8695d4www: New workflow step to install and run mypy on pydantic 1

* CU-8695d4www: Add type ignore comments to pydantic2 versions in versioning utils for typing during GHA workflow

* CU-8695d4www: Update pydantic requirement to 2.0+ only

* CU-8695d4www: Update to pydantic 2 ONLY

* CU-869671bn4: Update mypy dev requirement to be less than 1.12

* CU-869671bn4: Fix model fields in config

* CU-869671bn4: Fix stats helper method - use correct type adapter

* CU-869671bn4: Fix some model type issues

* CU-869671bn4: Line up with previous model dump methods

* CU-869671bn4: Fix overwriting model dump methods

* CU-869671bn4: Remove pydantic1 workflow step
mart-r authored Nov 27, 2024
1 parent 3c44dcb commit bb41955
Showing 17 changed files with 86 additions and 76 deletions.
8 changes: 7 additions & 1 deletion .github/workflows/main.yml
@@ -31,6 +31,13 @@ jobs:
- name: Lint
run: |
flake8 medcat
- name: Pydantic 1 check
# NOTE: the following looks for uses of the pydantic1-specific .dict() method and .__fields__ attribute.
# If any are found (that are not annotated for pydantic1 backwards compatibility), a non-zero exit
# code is returned, which will halt the workflow and print out the offending parts
run: |
grep "\.__fields__" medcat -rI | grep -v "# 4pydantic1 - backwards compatibility" | tee /dev/stderr | test $(wc -l) -eq 0
grep "\.dict(" medcat -rI | grep -v "# 4pydantic1 - backwards compatibility" | tee /dev/stderr | test $(wc -l) -eq 0
- name: Test
run: |
all_files=$(git ls-files | grep '^tests/.*\.py$' | grep -v '/__init__\.py$' | sed 's/\.py$//' | sed 's/\//./g')
@@ -54,7 +61,6 @@ jobs:
repo: context.repo.repo
});
core.setOutput('latest_version', latestRelease.data.tag_name);
- name: Make sure there's no deprecated methods that should be removed.
# only run this for master -> production PR. I.e just before doing a release.
if: github.event.pull_request.base.ref == 'main' && github.event.pull_request.head.ref == 'production'
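The grep pipeline added to the workflow above can be sketched as a plain-Python check. This is a hypothetical re-implementation for illustration, not part of the PR; the marker string mirrors the annotation comment the workflow greps for:

```python
import re
from pathlib import Path

# Flag pydantic1-specific usages (.dict() calls and .__fields__ access)
# that are not annotated with the backwards-compatibility marker.
PATTERNS = [re.compile(r"\.__fields__"), re.compile(r"\.dict\(")]
MARKER = "# 4pydantic1 - backwards compatibility"


def find_offenders(root):
    """Return (path, line number, line) for every unannotated pydantic1 usage."""
    offenders = []
    for path in sorted(Path(root).rglob("*.py")):
        for lineno, line in enumerate(path.read_text().splitlines(), 1):
            if MARKER in line:
                continue  # explicitly annotated for pydantic1 compatibility
            if any(p.search(line) for p in PATTERNS):
                offenders.append((str(path), lineno, line.strip()))
    return offenders
```

Like the workflow step, a non-empty result would be treated as a failure.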
2 changes: 1 addition & 1 deletion install_requires.txt
@@ -19,6 +19,6 @@
'xxhash>=3.0.0' # allow later versions, tested with 3.1.0
'blis>=0.7.5,<1.0.0' # allow later versions, tested with 0.7.9, avoid 1.0.0 (depends on numpy 2)
'click>=8.0.4' # allow later versions, tested with 8.1.3
'pydantic>=1.10.0,<2.0' # for spacy compatibility; avoid 2.0 due to breaking changes
'pydantic>=2.0.0,<3.0' # avoid next major release
"humanfriendly~=10.0" # for human readable file / RAM sizes
"peft>=0.8.2"
2 changes: 1 addition & 1 deletion medcat/cat.py
@@ -590,7 +590,7 @@ def _print_stats(self,

def _init_ckpts(self, is_resumed, checkpoint):
if self.config.general.checkpoint.steps is not None or checkpoint is not None:
checkpoint_config = CheckpointConfig(**self.config.general.checkpoint.dict())
checkpoint_config = CheckpointConfig(**self.config.general.checkpoint.model_dump())
checkpoint_manager = CheckpointManager('cat_train', checkpoint_config)
if is_resumed:
# TODO: probably remove is_resumed mark and always resume if a checkpoint is provided,
41 changes: 21 additions & 20 deletions medcat/config.py
@@ -1,6 +1,5 @@
from datetime import datetime
from pydantic import BaseModel, Extra, ValidationError
from pydantic.fields import ModelField
from pydantic import BaseModel, ValidationError
from typing import List, Set, Tuple, cast, Any, Callable, Dict, Optional, Union, Type, Literal
from multiprocessing import cpu_count
import logging
@@ -125,7 +124,7 @@ def merge_config(self, config_dict: Dict) -> None:
attr = None # new attribute
value = config_dict[key]
if isinstance(value, BaseModel):
value = value.dict()
value = value.model_dump()
if isinstance(attr, MixingConfig):
attr.merge_config(value)
else:
@@ -177,7 +176,7 @@ def rebuild_re(self) -> None:
def _calc_hash(self, hasher: Optional[Hasher] = None) -> Hasher:
if hasher is None:
hasher = Hasher()
for _, v in cast(BaseModel, self).dict().items():
for _, v in cast(BaseModel, self).model_dump().items():
if isinstance(v, MixingConfig):
v._calc_hash(hasher)
else:
@@ -189,7 +188,7 @@ def get_hash(self, hasher: Optional[Hasher] = None):
return hasher.hexdigest()

def __str__(self) -> str:
return str(cast(BaseModel, self).dict())
return str(cast(BaseModel, self).model_dump())

@classmethod
def load(cls, save_path: str) -> "MixingConfig":
@@ -238,15 +237,15 @@ def asdict(self) -> Dict[str, Any]:
Returns:
Dict[str, Any]: The dictionary associated with this config
"""
return cast(BaseModel, self).dict()
return cast(BaseModel, self).model_dump()

def fields(self) -> Dict[str, ModelField]:
def fields(self) -> dict:
"""Get the fields associated with this config.
Returns:
Dict[str, ModelField]: The dictionary of the field names and fields
dict: The dictionary of the field names and fields
"""
return cast(BaseModel, self).__fields__
return cast(BaseModel, self).model_fields
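The substitutions in this hunk follow the standard pydantic v1→v2 API migration: the `.dict()` instance method becomes `.model_dump()`, and the `__fields__` attribute becomes `model_fields`. A minimal sketch of the renamed API on an illustrative model (not MedCAT's actual config):

```python
from pydantic import BaseModel

class CheckPoint(BaseModel):
    steps: int = 1000
    max_to_keep: int = 1

cp = CheckPoint()
# pydantic 2 renames the v1 .dict() instance method to .model_dump()
assert cp.model_dump() == {"steps": 1000, "max_to_keep": 1}
# ...and replaces the v1 __fields__ attribute with model_fields,
# a plain dict mapping field name -> FieldInfo, available on the class
assert set(CheckPoint.model_fields) == {"steps", "max_to_keep"}
```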


class VersionInfo(MixingConfig, BaseModel):
@@ -272,7 +271,7 @@ class VersionInfo(MixingConfig, BaseModel):
"""Which version of medcat was used to build the CDB"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -290,7 +289,7 @@ class CDBMaker(MixingConfig, BaseModel):
"""Minimum number of letters required in a name to be accepted for a concept"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -303,7 +302,7 @@ class AnnotationOutput(MixingConfig, BaseModel):
include_text_in_output: bool = False

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -317,7 +316,7 @@ class CheckPoint(MixingConfig, BaseModel):
"""When training the maximum checkpoints will be kept on the disk"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -354,7 +353,7 @@ class General(MixingConfig, BaseModel):
NB! For these changes to take effect, the pipe would need to be recreated."""
checkpoint: CheckPoint = CheckPoint()
usage_monitor = UsageMonitor()
usage_monitor: UsageMonitor = UsageMonitor()
"""Checkpointing config"""
log_level: int = logging.INFO
"""Logging config for everything | 'tagger' can be disabled, but will cause a drop in performance"""
@@ -395,7 +394,7 @@ class General(MixingConfig, BaseModel):
reliable due to not taking into account all the details of the changes."""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -424,7 +423,7 @@ class Preprocessing(MixingConfig, BaseModel):
NB! For these changes to take effect, the pipe would need to be recreated."""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -444,7 +443,7 @@ class Ner(MixingConfig, BaseModel):
"""Try reverse word order for short concepts (2 words max), e.g. heart disease -> disease heart"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -579,7 +578,7 @@ class Linking(MixingConfig, BaseModel):
"""If true when the context of a concept is calculated (embedding) the words making that concept are not taken into account"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -600,7 +599,7 @@ class Config:
# this if for word_skipper and punct_checker which would otherwise
# not have a validator
arbitrary_types_allowed = True
extra = Extra.allow
extra = 'allow'
validate_assignment = True

def __init__(self, *args, **kwargs):
@@ -618,7 +617,7 @@ def rebuild_re(self) -> None:
# Override
def get_hash(self):
hasher = Hasher()
for k, v in self.dict().items():
for k, v in self.model_dump().items():
if k in ['hash', ]:
# ignore hash
continue
@@ -674,4 +673,6 @@ def wrapper(*args, **kwargs):
# we get a nicer exception
_waf_advice = "You can use `cat.cdb.weighted_average_function` to access it directly"
Linking.__getattribute__ = _wrapper(Linking.__getattribute__, Linking, _waf_advice, AttributeError) # type: ignore
if hasattr(Linking, '__getattr__'):
Linking.__getattr__ = _wrapper(Linking.__getattr__, Linking, _waf_advice, AttributeError) # type: ignore
Linking.__getitem__ = _wrapper(Linking.__getitem__, Linking, _waf_advice, KeyError) # type: ignore
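The recurring `Extra.allow` → `'allow'` substitution in this file works because pydantic 2 accepts a plain string for `extra` (the v1 `Extra` enum is deprecated in 2.x). The class-based `Config` shown in the diff still works, though `model_config = ConfigDict(...)` is the preferred v2 spelling. A sketch on an illustrative model:

```python
from pydantic import BaseModel, ConfigDict

class General(BaseModel):
    # v2 equivalent of:  class Config: extra = 'allow'; validate_assignment = True
    model_config = ConfigDict(extra="allow", validate_assignment=True)
    log_level: int = 20

g = General(custom_flag=True)   # extra="allow" keeps unknown fields...
assert g.custom_flag is True    # ...and exposes them as attributes
g.log_level = "10"              # validate_assignment re-validates on mutation
assert g.log_level == 10        # lax mode coerces the numeric string to int
```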
12 changes: 6 additions & 6 deletions medcat/config_meta_cat.py
@@ -1,5 +1,5 @@
from typing import Dict, Any
from medcat.config import MixingConfig, BaseModel, Optional, Extra
from medcat.config import MixingConfig, BaseModel, Optional


class General(MixingConfig, BaseModel):
@@ -65,7 +65,7 @@ class General(MixingConfig, BaseModel):
Otherwise defaults to doc._.ents or doc.ents per the annotate_overlapping settings"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -169,7 +169,7 @@ class Model(MixingConfig, BaseModel):
"""If set to True center positions will be ignored when calculating representation"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -191,7 +191,7 @@ class Train(MixingConfig, BaseModel):
"""If set only this CUIs will be used for training"""
auto_save_model: bool = True
"""Should do model be saved during training for best results"""
last_train_on: Optional[int] = None
last_train_on: Optional[float] = None
"""When was the last training run"""
metric: Dict[str, str] = {'base': 'weighted avg', 'score': 'f1-score'}
"""What metric should be used for choosing the best model"""
@@ -206,7 +206,7 @@ class Train(MixingConfig, BaseModel):
"""Focal Loss hyperparameter - determines importance the loss gives to hard-to-classify examples"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -217,5 +217,5 @@ class ConfigMetaCAT(MixingConfig, BaseModel):
train: Train = Train()

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True
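The `Optional[int]` → `Optional[float]` change for `last_train_on` matters because `datetime.now().timestamp()` returns a float, and with `validate_assignment = True` pydantic 2's lax mode refuses to silently truncate a fractional float to `int`. A sketch on an illustrative model (not MedCAT's actual config):

```python
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, ConfigDict, ValidationError

class Train(BaseModel):
    model_config = ConfigDict(validate_assignment=True)
    last_train_on: Optional[float] = None  # was Optional[int] before this PR

t = Train()
t.last_train_on = datetime.now().timestamp()  # timestamp() is a float
assert isinstance(t.last_train_on, float)

class OldTrain(BaseModel):
    model_config = ConfigDict(validate_assignment=True)
    last_train_on: Optional[int] = None

old = OldTrain()
try:
    old.last_train_on = 1_700_000_000.5  # fractional part -> not coercible
    raised = False
except ValidationError:
    raised = True
assert raised  # pydantic 2 rejects lossy float -> int coercion
```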
8 changes: 4 additions & 4 deletions medcat/config_rel_cat.py
@@ -1,6 +1,6 @@
import logging
from typing import Dict, Any, List
from medcat.config import MixingConfig, BaseModel, Optional, Extra
from medcat.config import MixingConfig, BaseModel, Optional


class General(MixingConfig, BaseModel):
@@ -89,7 +89,7 @@ class Model(MixingConfig, BaseModel):
"""If set to True center positions will be ignored when calculating representation"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -116,7 +116,7 @@ class Train(MixingConfig, BaseModel):
"""Should the model be saved during training for best results"""

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -127,5 +127,5 @@ class ConfigRelCAT(MixingConfig, BaseModel):
train: Train = Train()

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True
8 changes: 4 additions & 4 deletions medcat/config_transformers_ner.py
@@ -1,4 +1,4 @@
from medcat.config import MixingConfig, BaseModel, Optional, Extra
from medcat.config import MixingConfig, BaseModel, Optional


class General(MixingConfig, BaseModel):
@@ -16,11 +16,11 @@ class General(MixingConfig, BaseModel):
chunking_overlap_window: Optional[int] = 5
"""Size of the overlap window used for chunking"""
test_size: float = 0.2
last_train_on: Optional[int] = None
last_train_on: Optional[float] = None
verbose_metrics: bool = False

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True


@@ -29,5 +29,5 @@ class ConfigTransformersNER(MixingConfig, BaseModel):
general: General = General()

class Config:
extra = Extra.allow
extra = 'allow'
validate_assignment = True
6 changes: 3 additions & 3 deletions medcat/meta_cat.py
@@ -114,8 +114,8 @@ def get_hash(self) -> str:
"""
hasher = Hasher()
# Set last_train_on if None
if self.config.train['last_train_on'] is None:
self.config.train['last_train_on'] = datetime.now().timestamp()
if self.config.train.last_train_on is None:
self.config.train.last_train_on = datetime.now().timestamp()

hasher.update(self.config.get_hash())
return hasher.hexdigest()
@@ -310,7 +310,7 @@ def train_raw(self, data_loaded: Dict, save_dir_path: Optional[str] = None, data
# Save everything now
self.save(save_dir_path=save_dir_path)

self.config.train['last_train_on'] = datetime.now().timestamp()
self.config.train.last_train_on = datetime.now().timestamp()
return report

def eval(self, json_path: str) -> Dict:
6 changes: 3 additions & 3 deletions medcat/ner/transformers_ner.py
@@ -103,8 +103,8 @@ def get_hash(self) -> str:
"""
hasher = Hasher()
# Set last_train_on if None
if self.config.general['last_train_on'] is None:
self.config.general['last_train_on'] = datetime.now().timestamp()
if self.config.general.last_train_on is None:
self.config.general.last_train_on = datetime.now().timestamp()

hasher.update(self.config.get_hash())
return hasher.hexdigest()
@@ -242,7 +242,7 @@ def train(self,
trainer.train() # type: ignore

# Save the training time
self.config.general['last_train_on'] = datetime.now().timestamp() # type: ignore
self.config.general.last_train_on = datetime.now().timestamp() # type: ignore

# Save everything
self.save(save_dir_path=os.path.join(self.training_arguments.output_dir, 'final_model'))
2 changes: 1 addition & 1 deletion medcat/utils/regression/checking.py
@@ -411,7 +411,7 @@ def to_dict(self) -> dict:
d = {}
for case in self.cases:
d[case.name] = case.to_dict()
d['meta'] = self.metadata.dict()
d['meta'] = self.metadata.model_dump()
fix_np_float64(d['meta'])

return d
4 changes: 2 additions & 2 deletions medcat/utils/regression/regression_checker.py
@@ -118,8 +118,8 @@ def main(model_pack_dir: Path, test_suite_file: Path,
examples_strictness = Strictness[examples_strictness_str]
if jsonpath:
logger.info('Writing to %s', str(jsonpath))
jsonpath.write_text(json.dumps(res.dict(strictness=examples_strictness),
indent=jsonindent))
dumped = res.model_dump(strictness=examples_strictness)
jsonpath.write_text(json.dumps(dumped, indent=jsonindent))
else:
logger.info(res.get_report(phrases_separately=phrases,
hide_empty=hide_empty, examples_strictness=examples_strictness,