Skip to content

Commit

Permalink
fix: circular import - move funcs to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
paluchasz committed Dec 11, 2024
1 parent cb515d0 commit a8bf7a7
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 53 deletions.
62 changes: 58 additions & 4 deletions kazu/training/modelling_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import copy
import json
import logging
import random
from collections.abc import Iterable
from pathlib import Path
from typing import Any, Iterable, Optional
from typing import Any, Optional, Union

from hydra.utils import instantiate
from omegaconf import DictConfig
Expand All @@ -9,12 +13,17 @@
LabelStudioAnnotationView,
LabelStudioManager,
)
from kazu.data import ENTITY_OUTSIDE_SYMBOL, Document, Entity, Section
from kazu.training.train_multilabel_ner import (
LSManagerViewWrapper,
from kazu.data import (
ENTITY_OUTSIDE_SYMBOL,
PROCESSING_EXCEPTION,
Document,
Entity,
Section,
)
from kazu.utils.utils import PathLike

logger = logging.getLogger(__name__)


def doc_yielder(path: PathLike) -> Iterable[Document]:
for file in Path(path).iterdir():
Expand Down Expand Up @@ -70,6 +79,51 @@ def get_label_list_from_model(model_config_path: PathLike) -> list[str]:
return label_list


class LSManagerViewWrapper:
def __init__(self, view: LabelStudioAnnotationView, ls_manager: LabelStudioManager):
self.ls_manager = ls_manager
self.view = view

def get_gold_ents_for_side_by_side_view(self, docs: list[Document]) -> list[list[Document]]:
result = []
for doc in docs:
doc_cp = copy.deepcopy(doc)
if PROCESSING_EXCEPTION in doc_cp.metadata:
logger.error(doc.metadata[PROCESSING_EXCEPTION])
break
for section in doc_cp.sections:
gold_ents = []
for ent in section.metadata.get("gold_entities", []):
if isinstance(ent, dict):
ent = Entity.from_dict(ent)
gold_ents.append(ent)
section.entities = gold_ents
result.append([doc_cp, doc])
return result

def update(
self, test_docs: list[Document], global_step: Union[int, str], has_gs: bool = True
) -> None:
ls_manager = LabelStudioManager(
headers=self.ls_manager.headers,
project_name=f"{self.ls_manager.project_name}_test_{global_step}",
)

ls_manager.delete_project_if_exists()
ls_manager.create_linking_project()
docs_subset = random.sample(test_docs, min([len(test_docs), 100]))
if not docs_subset:
logger.info("no results to represent yet")
return
if has_gs:
side_by_side = self.get_gold_ents_for_side_by_side_view(docs_subset)
ls_manager.update_view(self.view, side_by_side)
ls_manager.update_tasks(side_by_side)
else:
ls_manager.update_view(self.view, docs_subset)
ls_manager.update_tasks(docs_subset)


def create_wrapper(cfg: DictConfig, label_list: list[str]) -> Optional[LSManagerViewWrapper]:
if cfg.get("label_studio_manager") and cfg.get("css_colors"):
ls_manager: LabelStudioManager = instantiate(cfg.label_studio_manager)
Expand Down
50 changes: 1 addition & 49 deletions kazu/training/train_multilabel_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging
import math
import pickle
import random
import shutil
import tempfile
from collections import defaultdict
Expand All @@ -27,12 +26,10 @@
)

from kazu.annotation.acceptance_test import aggregate_ner_results, score_sections
from kazu.annotation.label_studio import LabelStudioAnnotationView, LabelStudioManager
from kazu.data import (
ENTITY_OUTSIDE_SYMBOL,
PROCESSING_EXCEPTION,
Document,
Entity,
NumericMetric,
Section,
)
Expand All @@ -47,56 +44,11 @@
DebertaForMultiLabelTokenClassification,
DistilBertForMultiLabelTokenClassification,
)
from kazu.training.modelling_utils import chunks
from kazu.training.modelling_utils import LSManagerViewWrapper, chunks

logger = logging.getLogger(__name__)


class LSManagerViewWrapper:
def __init__(self, view: LabelStudioAnnotationView, ls_manager: LabelStudioManager):
self.ls_manager = ls_manager
self.view = view

def get_gold_ents_for_side_by_side_view(self, docs: list[Document]) -> list[list[Document]]:
result = []
for doc in docs:
doc_cp = copy.deepcopy(doc)
if PROCESSING_EXCEPTION in doc_cp.metadata:
logger.error(doc.metadata[PROCESSING_EXCEPTION])
break
for section in doc_cp.sections:
gold_ents = []
for ent in section.metadata.get("gold_entities", []):
if isinstance(ent, dict):
ent = Entity.from_dict(ent)
gold_ents.append(ent)
section.entities = gold_ents
result.append([doc_cp, doc])
return result

def update(
self, test_docs: list[Document], global_step: Union[int, str], has_gs: bool = True
) -> None:
ls_manager = LabelStudioManager(
headers=self.ls_manager.headers,
project_name=f"{self.ls_manager.project_name}_test_{global_step}",
)

ls_manager.delete_project_if_exists()
ls_manager.create_linking_project()
docs_subset = random.sample(test_docs, min([len(test_docs), 100]))
if not docs_subset:
logger.info("no results to represent yet")
return
if has_gs:
side_by_side = self.get_gold_ents_for_side_by_side_view(docs_subset)
ls_manager.update_view(self.view, side_by_side)
ls_manager.update_tasks(side_by_side)
else:
ls_manager.update_view(self.view, docs_subset)
ls_manager.update_tasks(docs_subset)


@dataclasses.dataclass
class SavedModel:
path: Path
Expand Down

0 comments on commit a8bf7a7

Please sign in to comment.