Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: Enable type checks with mypy #62

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 4 additions & 11 deletions medkit/core/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,15 @@ class Operation(abc.ABC):
>>> super().__init__(**init_args)
"""

uid: str
_description: OperationDescription | None = None
_prov_tracer: ProvTracer | None = None

@abc.abstractmethod
def __init__(self, uid: str | None = None, name: str | None = None, **kwargs):
if uid is None:
uid = generate_id()
if name is None:
name = self.__class__.__name__

self.uid = uid
self.uid = uid or generate_id()
self._description = OperationDescription(
uid=self.uid,
class_name=self.__class__.__name__,
name=name,
name=name or self.__class__.__name__,
config=kwargs,
)

Expand All @@ -68,9 +61,9 @@ def description(self) -> OperationDescription:
"""Contains all the operation init parameters."""
return self._description

def check_sanity(self) -> bool: # noqa: B027
def check_sanity(self) -> None:
# TODO: add some checks
pass
return


class DocOperation(Operation):
Expand Down
32 changes: 17 additions & 15 deletions medkit/io/_brat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from smart_open import open

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from pathlib import Path

GROUPING_ENTITIES = frozenset(["And-Group", "Or-Group"])
Expand All @@ -22,7 +23,7 @@ class BratEntity:

uid: str
type: str
span: list[tuple[int, int]]
span: Sequence[tuple[int, int]]
text: str

@property
Expand Down Expand Up @@ -58,7 +59,7 @@ class BratAttribute:
uid: str
type: str
target: str
value: str = None # Only one value is possible
value: str | None = None # Only one value is possible

def to_str(self) -> str:
value = ensure_attr_value(self.value)
Expand All @@ -80,7 +81,7 @@ def to_str(self) -> str:


def ensure_attr_value(attr_value: Any) -> str:
"""Ensure that the attribue value is a string."""
"""Ensure that the attribute value is a string."""
if isinstance(attr_value, str):
return attr_value
if attr_value is None or isinstance(attr_value, bool):
Expand All @@ -98,7 +99,7 @@ class Grouping:

uid: str
type: str
items: list[BratEntity]
items: Sequence[BratEntity]

@property
def text(self):
Expand All @@ -111,11 +112,11 @@ class BratAugmentedEntity:

uid: str
type: str
span: tuple[tuple[int, int], ...]
span: Sequence[tuple[int, int]]
text: str
relations_from_me: tuple[BratRelation, ...]
relations_to_me: tuple[BratRelation, ...]
attributes: tuple[BratAttribute, ...]
relations_from_me: Sequence[BratRelation]
relations_to_me: Sequence[BratRelation]
attributes: Sequence[BratAttribute]

@property
def start(self) -> int:
Expand All @@ -128,11 +129,11 @@ def end(self) -> int:

@dataclass
class BratDocument:
entities: dict[str, BratEntity]
relations: dict[str, BratRelation]
attributes: dict[str, BratAttribute]
notes: dict[str, BratNote]
groups: dict[str, Grouping] = None
entities: Mapping[str, BratEntity]
relations: Mapping[str, BratRelation]
attributes: Mapping[str, BratAttribute]
notes: Mapping[str, BratNote]
groups: Mapping[str, Grouping] | None = None

def get_augmented_entities(self) -> dict[str, BratAugmentedEntity]:
augmented_entities = {}
Expand Down Expand Up @@ -374,9 +375,8 @@ def parse_string(ann_string: str, detect_groups: bool = False) -> BratDocument:
logger.warning("Ignore annotation %s at line %s", ann_id, line_number)

# Process groups
groups = None
if detect_groups:
groups: dict[str, Grouping] = {}
groups = {}
grouping_relations = {r.uid: r for r in relations.values() if r.type in GROUPING_RELATIONS}

for entity in entities.values():
Expand All @@ -385,6 +385,8 @@ def parse_string(ann_string: str, detect_groups: bool = False) -> BratDocument:
entities[relation.obj] for relation in grouping_relations.values() if relation.subj == entity.uid
]
groups[entity.uid] = Grouping(entity.uid, entity.type, items)
else:
groups = None

return BratDocument(entities, relations, attributes, notes, groups)

Expand Down
12 changes: 5 additions & 7 deletions medkit/text/ner/umls_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,10 @@ def load_umls_entries(
if lui in luis_seen:
continue

if semtypes_by_cui is not None and cui in semtypes_by_cui:
semtypes = semtypes_by_cui[cui]
semgroups = [semgroups_by_semtype[semtype] for semtype in semtypes]
else:
semtypes = None
semgroups = None
semtypes = semtypes_by_cui.get(cui) if semtypes_by_cui else None
semgroups = (
[semgroups_by_semtype[semtype] for semtype in semtypes] if semgroups_by_semtype and semtypes else None
)

luis_seen.add(lui)
yield UMLSEntry(cui, term, semtypes, semgroups)
Expand Down Expand Up @@ -198,7 +196,7 @@ def load_semtypes_by_cui(mrsty_file: str | Path) -> dict[str, list[str]]:
# Source: UMLS project
# https://lhncbc.nlm.nih.gov/semanticnetwork/download/sg_archive/SemGroups-v04.txt
_UMLS_SEMGROUPS_FILE = Path(__file__).parent / "umls_semgroups_v04.txt"
_SEMGROUPS_BY_SEMTYPE = None
_SEMGROUPS_BY_SEMTYPE: dict[str, str] | None = None


def load_semgroups_by_semtype() -> dict[str, str]:
Expand Down
6 changes: 4 additions & 2 deletions medkit/training/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from tqdm import tqdm

if TYPE_CHECKING:
from collections.abc import Mapping

from medkit.training.trainer_config import TrainerConfig


Expand All @@ -23,7 +25,7 @@ def on_train_end(self):
def on_epoch_begin(self, epoch: int):
"""Event called at the beginning of an epoch."""

def on_epoch_end(self, metrics: dict[str, float], epoch: int, epoch_time: float):
def on_epoch_end(self, metrics: Mapping[str, Mapping[str, float]], epoch: int, epoch_duration: float):
"""Event called at the end of an epoch."""

def on_step_begin(self, step_idx: int, nb_batches: int, phase: str):
Expand Down Expand Up @@ -66,7 +68,7 @@ def on_train_begin(self, config):
)
self.logger.info(message)

def on_epoch_end(self, metrics, epoch, epoch_duration):
def on_epoch_end(self, metrics: Mapping[str, Mapping[str, float]], epoch: int, epoch_duration: float):
message = f"Epoch {epoch} ended (duration: {epoch_duration:.2f}s)\n"

train_metrics = metrics.get("train", None)
Expand Down
4 changes: 2 additions & 2 deletions medkit/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, runtime_checkable

import torch
from typing_extensions import Protocol, Self
from typing_extensions import Protocol


class BatchData(dict):
Expand All @@ -17,7 +17,7 @@ def __getitem__(self, index: int) -> dict[str, list[Any] | torch.Tensor]:
return inner_dict[index]
return {key: values[index] for key, values in self.items()}

def to_device(self, device: torch.device) -> Self:
def to_device(self, device: torch.device) -> BatchData:
"""Ensure that Tensors in the BatchData object are on the specified `device`.

Parameters
Expand Down