diff --git a/kazu/training/modelling_utils.py b/kazu/training/modelling_utils.py index 2f11bfea..5e7347b6 100644 --- a/kazu/training/modelling_utils.py +++ b/kazu/training/modelling_utils.py @@ -1,6 +1,10 @@ +import copy import json +import logging +import random +from collections.abc import Iterable from pathlib import Path -from typing import Any, Iterable, Optional +from typing import Any, Optional, Union from hydra.utils import instantiate from omegaconf import DictConfig @@ -9,12 +13,17 @@ LabelStudioAnnotationView, LabelStudioManager, ) -from kazu.data import ENTITY_OUTSIDE_SYMBOL, Document, Entity, Section -from kazu.training.train_multilabel_ner import ( - LSManagerViewWrapper, +from kazu.data import ( + ENTITY_OUTSIDE_SYMBOL, + PROCESSING_EXCEPTION, + Document, + Entity, + Section, ) from kazu.utils.utils import PathLike +logger = logging.getLogger(__name__) + def doc_yielder(path: PathLike) -> Iterable[Document]: for file in Path(path).iterdir(): @@ -70,6 +79,51 @@ def get_label_list_from_model(model_config_path: PathLike) -> list[str]: return label_list +class LSManagerViewWrapper: + def __init__(self, view: LabelStudioAnnotationView, ls_manager: LabelStudioManager): + self.ls_manager = ls_manager + self.view = view + + def get_gold_ents_for_side_by_side_view(self, docs: list[Document]) -> list[list[Document]]: + result = [] + for doc in docs: + doc_cp = copy.deepcopy(doc) + if PROCESSING_EXCEPTION in doc_cp.metadata: + logger.error(doc.metadata[PROCESSING_EXCEPTION]) + break + for section in doc_cp.sections: + gold_ents = [] + for ent in section.metadata.get("gold_entities", []): + if isinstance(ent, dict): + ent = Entity.from_dict(ent) + gold_ents.append(ent) + section.entities = gold_ents + result.append([doc_cp, doc]) + return result + + def update( + self, test_docs: list[Document], global_step: Union[int, str], has_gs: bool = True + ) -> None: + ls_manager = LabelStudioManager( + headers=self.ls_manager.headers, + project_name=f"{self.ls_manager.project_name}_test_{global_step}", + ) + + ls_manager.delete_project_if_exists() + ls_manager.create_linking_project() + docs_subset = random.sample(test_docs, min([len(test_docs), 100])) + if not docs_subset: + logger.info("no results to represent yet") + return + if has_gs: + side_by_side = self.get_gold_ents_for_side_by_side_view(docs_subset) + ls_manager.update_view(self.view, side_by_side) + ls_manager.update_tasks(side_by_side) + else: + ls_manager.update_view(self.view, docs_subset) + ls_manager.update_tasks(docs_subset) + + def create_wrapper(cfg: DictConfig, label_list: list[str]) -> Optional[LSManagerViewWrapper]: if cfg.get("label_studio_manager") and cfg.get("css_colors"): ls_manager: LabelStudioManager = instantiate(cfg.label_studio_manager) diff --git a/kazu/training/train_multilabel_ner.py b/kazu/training/train_multilabel_ner.py index ad8922f0..715fb8cf 100644 --- a/kazu/training/train_multilabel_ner.py +++ b/kazu/training/train_multilabel_ner.py @@ -4,7 +4,6 @@ import logging import math import pickle -import random import shutil import tempfile from collections import defaultdict @@ -27,12 +26,10 @@ ) from kazu.annotation.acceptance_test import aggregate_ner_results, score_sections -from kazu.annotation.label_studio import LabelStudioAnnotationView, LabelStudioManager from kazu.data import ( ENTITY_OUTSIDE_SYMBOL, PROCESSING_EXCEPTION, Document, - Entity, NumericMetric, Section, ) @@ -47,56 +44,11 @@ DebertaForMultiLabelTokenClassification, DistilBertForMultiLabelTokenClassification, ) -from kazu.training.modelling_utils import chunks +from kazu.training.modelling_utils import LSManagerViewWrapper, chunks logger = logging.getLogger(__name__) -class LSManagerViewWrapper: - def __init__(self, view: LabelStudioAnnotationView, ls_manager: LabelStudioManager): - self.ls_manager = ls_manager - self.view = view - - def get_gold_ents_for_side_by_side_view(self, docs: list[Document]) -> list[list[Document]]: - result = [] - for doc in docs: - doc_cp = copy.deepcopy(doc) - if PROCESSING_EXCEPTION in doc_cp.metadata: - logger.error(doc.metadata[PROCESSING_EXCEPTION]) - break - for section in doc_cp.sections: - gold_ents = [] - for ent in section.metadata.get("gold_entities", []): - if isinstance(ent, dict): - ent = Entity.from_dict(ent) - gold_ents.append(ent) - section.entities = gold_ents - result.append([doc_cp, doc]) - return result - - def update( - self, test_docs: list[Document], global_step: Union[int, str], has_gs: bool = True - ) -> None: - ls_manager = LabelStudioManager( - headers=self.ls_manager.headers, - project_name=f"{self.ls_manager.project_name}_test_{global_step}", - ) - - ls_manager.delete_project_if_exists() - ls_manager.create_linking_project() - docs_subset = random.sample(test_docs, min([len(test_docs), 100])) - if not docs_subset: - logger.info("no results to represent yet") - return - if has_gs: - side_by_side = self.get_gold_ents_for_side_by_side_view(docs_subset) - ls_manager.update_view(self.view, side_by_side) - ls_manager.update_tasks(side_by_side) - else: - ls_manager.update_view(self.view, docs_subset) - ls_manager.update_tasks(docs_subset) - - @dataclasses.dataclass class SavedModel: path: Path