diff --git a/medkit/core/operation.py b/medkit/core/operation.py index 29057a28..fffc3980 100644 --- a/medkit/core/operation.py +++ b/medkit/core/operation.py @@ -34,22 +34,15 @@ class Operation(abc.ABC): >>> super().__init__(**init_args) """ - uid: str - _description: OperationDescription | None = None _prov_tracer: ProvTracer | None = None @abc.abstractmethod def __init__(self, uid: str | None = None, name: str | None = None, **kwargs): - if uid is None: - uid = generate_id() - if name is None: - name = self.__class__.__name__ - - self.uid = uid + self.uid = uid or generate_id() self._description = OperationDescription( uid=self.uid, class_name=self.__class__.__name__, - name=name, + name=name or self.__class__.__name__, config=kwargs, ) @@ -68,9 +61,9 @@ def description(self) -> OperationDescription: """Contains all the operation init parameters.""" return self._description - def check_sanity(self) -> bool: # noqa: B027 + def check_sanity(self) -> None: # TODO: add some checks - pass + return class DocOperation(Operation): diff --git a/medkit/io/_brat_utils.py b/medkit/io/_brat_utils.py index 7a3cdbdd..572001f5 100644 --- a/medkit/io/_brat_utils.py +++ b/medkit/io/_brat_utils.py @@ -8,6 +8,7 @@ from smart_open import open if TYPE_CHECKING: + from collections.abc import Mapping, Sequence from pathlib import Path GROUPING_ENTITIES = frozenset(["And-Group", "Or-Group"]) @@ -22,7 +23,7 @@ class BratEntity: uid: str type: str - span: list[tuple[int, int]] + span: Sequence[tuple[int, int]] text: str @property @@ -58,7 +59,7 @@ class BratAttribute: uid: str type: str target: str - value: str = None # Only one value is possible + value: str | None = None # Only one value is possible def to_str(self) -> str: value = ensure_attr_value(self.value) @@ -80,7 +81,7 @@ def to_str(self) -> str: def ensure_attr_value(attr_value: Any) -> str: - """Ensure that the attribue value is a string.""" + """Ensure that the attribute value is a string.""" if isinstance(attr_value, str): return attr_value if attr_value is None or isinstance(attr_value, bool): @@ -98,7 +99,7 @@ class Grouping: uid: str type: str - items: list[BratEntity] + items: Sequence[BratEntity] @property def text(self): @@ -111,11 +112,11 @@ class BratAugmentedEntity: uid: str type: str - span: tuple[tuple[int, int], ...] + span: Sequence[tuple[int, int]] text: str - relations_from_me: tuple[BratRelation, ...] - relations_to_me: tuple[BratRelation, ...] - attributes: tuple[BratAttribute, ...] + relations_from_me: Sequence[BratRelation] + relations_to_me: Sequence[BratRelation] + attributes: Sequence[BratAttribute] @property def start(self) -> int: @@ -128,11 +129,11 @@ def end(self) -> int: @dataclass class BratDocument: - entities: dict[str, BratEntity] - relations: dict[str, BratRelation] - attributes: dict[str, BratAttribute] - notes: dict[str, BratNote] - groups: dict[str, Grouping] = None + entities: Mapping[str, BratEntity] + relations: Mapping[str, BratRelation] + attributes: Mapping[str, BratAttribute] + notes: Mapping[str, BratNote] + groups: Mapping[str, Grouping] | None = None def get_augmented_entities(self) -> dict[str, BratAugmentedEntity]: augmented_entities = {} @@ -374,9 +375,8 @@ def parse_string(ann_string: str, detect_groups: bool = False) -> BratDocument: logger.warning("Ignore annotation %s at line %s", ann_id, line_number) # Process groups - groups = None if detect_groups: - groups: dict[str, Grouping] = {} + groups = {} grouping_relations = {r.uid: r for r in relations.values() if r.type in GROUPING_RELATIONS} for entity in entities.values(): @@ -385,6 +385,8 @@ def parse_string(ann_string: str, detect_groups: bool = False) -> BratDocument: entities[relation.obj] for relation in grouping_relations.values() if relation.subj == entity.uid ] groups[entity.uid] = Grouping(entity.uid, entity.type, items) + else: + groups = None return BratDocument(entities, relations, attributes, notes, groups) diff --git a/medkit/text/ner/umls_utils.py b/medkit/text/ner/umls_utils.py index 2fdde0bb..a9e447f0 100644 --- a/medkit/text/ner/umls_utils.py +++ b/medkit/text/ner/umls_utils.py @@ -155,12 +155,10 @@ def load_umls_entries( if lui in luis_seen: continue - if semtypes_by_cui is not None and cui in semtypes_by_cui: - semtypes = semtypes_by_cui[cui] - semgroups = [semgroups_by_semtype[semtype] for semtype in semtypes] - else: - semtypes = None - semgroups = None + semtypes = semtypes_by_cui.get(cui) if semtypes_by_cui else None + semgroups = ( + [semgroups_by_semtype[semtype] for semtype in semtypes] if semgroups_by_semtype and semtypes else None + ) luis_seen.add(lui) yield UMLSEntry(cui, term, semtypes, semgroups) @@ -198,7 +196,7 @@ def load_semtypes_by_cui(mrsty_file: str | Path) -> dict[str, list[str]]: # Source: UMLS project # https://lhncbc.nlm.nih.gov/semanticnetwork/download/sg_archive/SemGroups-v04.txt _UMLS_SEMGROUPS_FILE = Path(__file__).parent / "umls_semgroups_v04.txt" -_SEMGROUPS_BY_SEMTYPE = None +_SEMGROUPS_BY_SEMTYPE: dict[str, str] | None = None def load_semgroups_by_semtype() -> dict[str, str]: diff --git a/medkit/training/callbacks.py b/medkit/training/callbacks.py index 1299ce57..0f58e7a7 100644 --- a/medkit/training/callbacks.py +++ b/medkit/training/callbacks.py @@ -8,6 +8,8 @@ from tqdm import tqdm if TYPE_CHECKING: + from collections.abc import Mapping + from medkit.training.trainer_config import TrainerConfig @@ -23,7 +25,7 @@ def on_train_end(self): def on_epoch_begin(self, epoch: int): """Event called at the beginning of an epoch.""" - def on_epoch_end(self, metrics: dict[str, float], epoch: int, epoch_time: float): + def on_epoch_end(self, metrics: Mapping[str, Mapping[str, float]], epoch: int, epoch_duration: float): """Event called at the end of an epoch.""" def on_step_begin(self, step_idx: int, nb_batches: int, phase: str): @@ -66,7 +68,7 @@ def on_train_begin(self, config): ) self.logger.info(message) - def on_epoch_end(self, metrics, epoch, epoch_duration): + def on_epoch_end(self, metrics: Mapping[str, Mapping[str, float]], epoch: int, epoch_duration: float): message = f"Epoch {epoch} ended (duration: {epoch_duration:.2f}s)\n" train_metrics = metrics.get("train", None) diff --git a/medkit/training/utils.py b/medkit/training/utils.py index 5e687d20..1ddffb6c 100644 --- a/medkit/training/utils.py +++ b/medkit/training/utils.py @@ -5,7 +5,7 @@ from typing import Any, runtime_checkable import torch -from typing_extensions import Protocol, Self +from typing_extensions import Protocol class BatchData(dict): @@ -17,7 +17,7 @@ def __getitem__(self, index: int) -> dict[str, list[Any] | torch.Tensor]: return inner_dict[index] return {key: values[index] for key, values in self.items()} - def to_device(self, device: torch.device) -> Self: + def to_device(self, device: torch.device) -> BatchData: """Ensure that Tensors in the BatchData object are on the specified `device`. Parameters