Skip to content

Commit

Permalink
WIP: Fix some type checking errors
Browse files Browse the repository at this point in the history
  • Loading branch information
ghisvail committed Sep 12, 2024
1 parent 62bb242 commit f7e1e62
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 16 deletions.
32 changes: 17 additions & 15 deletions medkit/io/_brat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import logging
from collections import Counter, defaultdict
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NamedTuple

Expand All @@ -22,7 +23,7 @@ class BratEntity:

uid: str
type: str
span: list[tuple[int, int]]
span: Sequence[tuple[int, int]]
text: str

@property
Expand Down Expand Up @@ -58,7 +59,7 @@ class BratAttribute:
uid: str
type: str
target: str
value: str = None # Only one value is possible
value: str | None = None # Only one value is possible

def to_str(self) -> str:
value = ensure_attr_value(self.value)
Expand All @@ -80,7 +81,7 @@ def to_str(self) -> str:


def ensure_attr_value(attr_value: Any) -> str:
"""Ensure that the attribue value is a string."""
"""Ensure that the attribute value is a string."""
if isinstance(attr_value, str):
return attr_value
if attr_value is None or isinstance(attr_value, bool):
Expand All @@ -98,7 +99,7 @@ class Grouping:

uid: str
type: str
items: list[BratEntity]
items: Sequence[BratEntity]

@property
def text(self):
Expand All @@ -111,11 +112,11 @@ class BratAugmentedEntity:

uid: str
type: str
span: tuple[tuple[int, int], ...]
span: Sequence[tuple[int, int]]
text: str
relations_from_me: tuple[BratRelation, ...]
relations_to_me: tuple[BratRelation, ...]
attributes: tuple[BratAttribute, ...]
relations_from_me: Sequence[BratRelation]
relations_to_me: Sequence[BratRelation]
attributes: Sequence[BratAttribute]

@property
def start(self) -> int:
Expand All @@ -128,11 +129,11 @@ def end(self) -> int:

@dataclass
class BratDocument:
entities: dict[str, BratEntity]
relations: dict[str, BratRelation]
attributes: dict[str, BratAttribute]
notes: dict[str, BratNote]
groups: dict[str, Grouping] = None
entities: Mapping[str, BratEntity]
relations: Mapping[str, BratRelation]
attributes: Mapping[str, BratAttribute]
notes: Mapping[str, BratNote]
groups: Mapping[str, Grouping] | None = None

def get_augmented_entities(self) -> dict[str, BratAugmentedEntity]:
augmented_entities = {}
Expand Down Expand Up @@ -374,9 +375,8 @@ def parse_string(ann_string: str, detect_groups: bool = False) -> BratDocument:
logger.warning("Ignore annotation %s at line %s", ann_id, line_number)

# Process groups
groups = None
if detect_groups:
groups: dict[str, Grouping] = {}
groups = {}
grouping_relations = {r.uid: r for r in relations.values() if r.type in GROUPING_RELATIONS}

for entity in entities.values():
Expand All @@ -385,6 +385,8 @@ def parse_string(ann_string: str, detect_groups: bool = False) -> BratDocument:
entities[relation.obj] for relation in grouping_relations.values() if relation.subj == entity.uid
]
groups[entity.uid] = Grouping(entity.uid, entity.type, items)
else:
groups = None

return BratDocument(entities, relations, attributes, notes, groups)

Expand Down
2 changes: 1 addition & 1 deletion medkit/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __getitem__(self, index: int) -> dict[str, list[Any] | torch.Tensor]:
return inner_dict[index]
return {key: values[index] for key, values in self.items()}

def to_device(self, device: torch.device) -> Self:
def to_device(self, device: torch.device) -> BatchData:
"""Ensure that Tensors in the BatchData object are on the specified `device`.
Parameters
Expand Down

0 comments on commit f7e1e62

Please sign in to comment.