From 6a4f17c9b31aa9e4f88660be450389eed0c46c51 Mon Sep 17 00:00:00 2001 From: Ghislain Vaillant Date: Thu, 12 Sep 2024 16:14:39 +0200 Subject: [PATCH] WIP: Fix some type checking errors --- medkit/io/_brat_utils.py | 32 +++++++++++++++++--------------- medkit/training/utils.py | 4 ++-- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/medkit/io/_brat_utils.py b/medkit/io/_brat_utils.py index 7a3cdbdd..572001f5 100644 --- a/medkit/io/_brat_utils.py +++ b/medkit/io/_brat_utils.py @@ -8,6 +8,7 @@ from smart_open import open if TYPE_CHECKING: + from collections.abc import Mapping, Sequence from pathlib import Path GROUPING_ENTITIES = frozenset(["And-Group", "Or-Group"]) @@ -22,7 +23,7 @@ class BratEntity: uid: str type: str - span: list[tuple[int, int]] + span: Sequence[tuple[int, int]] text: str @property @@ -58,7 +59,7 @@ class BratAttribute: uid: str type: str target: str - value: str = None # Only one value is possible + value: str | None = None # Only one value is possible def to_str(self) -> str: value = ensure_attr_value(self.value) @@ -80,7 +81,7 @@ def to_str(self) -> str: def ensure_attr_value(attr_value: Any) -> str: - """Ensure that the attribue value is a string.""" + """Ensure that the attribute value is a string.""" if isinstance(attr_value, str): return attr_value if attr_value is None or isinstance(attr_value, bool): @@ -98,7 +99,7 @@ class Grouping: uid: str type: str - items: list[BratEntity] + items: Sequence[BratEntity] @property def text(self): @@ -111,11 +112,11 @@ class BratAugmentedEntity: uid: str type: str - span: tuple[tuple[int, int], ...] + span: Sequence[tuple[int, int]] text: str - relations_from_me: tuple[BratRelation, ...] - relations_to_me: tuple[BratRelation, ...] - attributes: tuple[BratAttribute, ...] + relations_from_me: Sequence[BratRelation] + relations_to_me: Sequence[BratRelation] + attributes: Sequence[BratAttribute] @property def start(self) -> int: @@ -128,11 +129,11 @@ def end(self) -> int: @dataclass class BratDocument: - entities: dict[str, BratEntity] - relations: dict[str, BratRelation] - attributes: dict[str, BratAttribute] - notes: dict[str, BratNote] - groups: dict[str, Grouping] = None + entities: Mapping[str, BratEntity] + relations: Mapping[str, BratRelation] + attributes: Mapping[str, BratAttribute] + notes: Mapping[str, BratNote] + groups: Mapping[str, Grouping] | None = None def get_augmented_entities(self) -> dict[str, BratAugmentedEntity]: augmented_entities = {} @@ -374,9 +375,8 @@ def parse_string(ann_string: str, detect_groups: bool = False) -> BratDocument: logger.warning("Ignore annotation %s at line %s", ann_id, line_number) # Process groups - groups = None if detect_groups: - groups: dict[str, Grouping] = {} + groups = {} grouping_relations = {r.uid: r for r in relations.values() if r.type in GROUPING_RELATIONS} for entity in entities.values(): @@ -385,6 +385,8 @@ def parse_string(ann_string: str, detect_groups: bool = False) -> BratDocument: entities[relation.obj] for relation in grouping_relations.values() if relation.subj == entity.uid ] groups[entity.uid] = Grouping(entity.uid, entity.type, items) + else: + groups = None return BratDocument(entities, relations, attributes, notes, groups) diff --git a/medkit/training/utils.py b/medkit/training/utils.py index 5e687d20..1ddffb6c 100644 --- a/medkit/training/utils.py +++ b/medkit/training/utils.py @@ -5,7 +5,7 @@ from typing import Any, runtime_checkable import torch -from typing_extensions import Protocol, Self +from typing_extensions import Protocol class BatchData(dict): @@ -17,7 +17,7 @@ def __getitem__(self, index: int) -> dict[str, list[Any] | torch.Tensor]: return inner_dict[index] return {key: values[index] for key, values in self.items()} - def to_device(self, device: torch.device) -> Self: + def to_device(self, device: torch.device) -> BatchData: """Ensure that Tensors in the BatchData object are on the specified `device`. Parameters