From 74cd82cd23010758b78796d364e9025a51c13b19 Mon Sep 17 00:00:00 2001 From: Szymon Palucha Date: Mon, 25 Nov 2024 16:30:36 +0000 Subject: [PATCH] refactor: move functions to a reusable module --- kazu/training/evaluate_script.py | 11 ++--- kazu/training/modelling_utils.py | 76 ++++++++++++++++++++++++++++++++ kazu/training/predict_script.py | 10 ++--- kazu/training/train_script.py | 64 +++------------------------ 4 files changed, 92 insertions(+), 69 deletions(-) create mode 100644 kazu/training/modelling_utils.py diff --git a/kazu/training/evaluate_script.py b/kazu/training/evaluate_script.py index 5bce2a33..1ebe25ed 100644 --- a/kazu/training/evaluate_script.py +++ b/kazu/training/evaluate_script.py @@ -17,12 +17,16 @@ ) from kazu.steps.ner.tokenized_word_processor import TokenizedWordProcessor from kazu.training.config import PredictionConfig +from kazu.training.modelling_utils import ( + create_wrapper, + doc_yielder, + get_label_list_from_model, +) from kazu.training.train_multilabel_ner import ( _select_keys_to_use, calculate_metrics, move_entities_to_metadata, ) -from kazu.training.train_script import create_wrapper, doc_yielder from kazu.utils.constants import HYDRA_VERSION_BASE @@ -36,10 +40,7 @@ def main(cfg: DictConfig) -> None: prediction_config: PredictionConfig = instantiate(cfg.prediction_config) - with open(Path(prediction_config.path) / "config.json", "r") as file: - model_config = json.load(file) - id2label = {int(idx): label for idx, label in model_config["id2label"].items()} - label_list = [label for _, label in sorted(id2label.items())] + label_list = get_label_list_from_model(Path(prediction_config.path) / "config.json") print(f"There are {len(label_list)} labels.") step = TransformersModelForTokenClassificationNerStep( diff --git a/kazu/training/modelling_utils.py b/kazu/training/modelling_utils.py new file mode 100644 index 00000000..95f33350 --- /dev/null +++ b/kazu/training/modelling_utils.py @@ -0,0 +1,76 @@ +import json +from pathlib import Path +from typing import Iterable, Optional + +from hydra.utils import instantiate +from omegaconf import DictConfig + +from kazu.annotation.label_studio import ( + LabelStudioAnnotationView, + LabelStudioManager, +) +from kazu.data import ENTITY_OUTSIDE_SYMBOL, Document, Entity, Section +from kazu.training.train_multilabel_ner import ( + LSManagerViewWrapper, +) +from kazu.utils.utils import PathLike + + +def doc_yielder(path: PathLike) -> Iterable[Document]: + for file in Path(path).iterdir(): + with file.open(mode="r") as f: + try: + yield Document.from_json(f.read()) + except Exception as e: + print(f"failed to read: {file}, {e}") + + +def test_doc_yielder() -> Iterable[Document]: + section = Section(text="abracodabravir detameth targets BEHATHT.", name="test1") + section.entities.append( + Entity.load_contiguous_entity( + start=0, end=23, match="abracodabravir detameth", entity_class="drug", namespace="test" + ) + ) + section.entities.append( + Entity.load_contiguous_entity( + start=15, end=23, match="detameth", entity_class="salt", namespace="test" + ) + ) + section.entities.append( + Entity.load_contiguous_entity( + start=32, end=39, match="BEHATHT", entity_class="gene", namespace="test" + ) + ) + doc = Document(sections=[section]) + yield doc + + +def get_label_list(path: PathLike) -> list[str]: + label_list = set() + for doc in doc_yielder(path): + for entity in doc.get_entities(): + label_list.add(entity.entity_class) + label_list.add(ENTITY_OUTSIDE_SYMBOL) + # needs deterministic order for consistency + return sorted(list(label_list)) + + +def get_label_list_from_model(model_config_path: PathLike) -> list[str]: + with open(model_config_path, "r") as file: + model_config = json.load(file) + id2label = {int(idx): label for idx, label in model_config["id2label"].items()} + label_list = [label for _, label in sorted(id2label.items())] + return label_list + + +def create_wrapper(cfg: DictConfig, label_list: list[str]) -> Optional[LSManagerViewWrapper]: + if cfg.get("label_studio_manager") and cfg.get("css_colors"): + ls_manager: LabelStudioManager = instantiate(cfg.label_studio_manager) + css_colors = cfg.css_colors + label_to_color = {} + for i, label in enumerate(label_list): + label_to_color[label] = css_colors[i] + view = LabelStudioAnnotationView(ner_labels=label_to_color) + return LSManagerViewWrapper(view, ls_manager) + return None diff --git a/kazu/training/predict_script.py b/kazu/training/predict_script.py index 3a60e645..dbaa03c4 100644 --- a/kazu/training/predict_script.py +++ b/kazu/training/predict_script.py @@ -1,7 +1,6 @@ """Use this script to test the model with custom text inputs and visualize the predictions in Label Studio.""" -import json from pathlib import Path import hydra @@ -15,8 +14,8 @@ ) from kazu.steps.ner.tokenized_word_processor import TokenizedWordProcessor from kazu.training.config import PredictionConfig +from kazu.training.modelling_utils import create_wrapper, get_label_list_from_model from kazu.training.train_multilabel_ner import _select_keys_to_use -from kazu.training.train_script import create_wrapper from kazu.utils.constants import HYDRA_VERSION_BASE @@ -30,10 +29,7 @@ def main(cfg: DictConfig) -> None: prediction_config: PredictionConfig = instantiate(cfg.prediction_config) - with open(Path(prediction_config.path) / "config.json", "r") as file: - model_config = json.load(file) - id2label = {int(idx): label for idx, label in model_config["id2label"].items()} - label_list = [label for _, label in sorted(id2label.items())] + label_list = get_label_list_from_model(Path(prediction_config.path) / "config.json") print(f"There are {len(label_list)} labels.") step = TransformersModelForTokenClassificationNerStep( @@ -51,7 +47,7 @@ def main(cfg: DictConfig) -> None: manager = create_wrapper(cfg, label_list) if manager is not None: - manager.update(documents, 0, has_gs=False) + manager.update(documents, "custom_predictions", has_gs=False) if __name__ == "__main__": diff --git a/kazu/training/train_script.py b/kazu/training/train_script.py index b3ce449d..fac37c75 100644 --- a/kazu/training/train_script.py +++ b/kazu/training/train_script.py @@ -1,7 +1,6 @@ import os from multiprocessing import freeze_support from pathlib import Path -from typing import Iterable, Optional import hydra from hydra.utils import instantiate @@ -11,47 +10,20 @@ from kazu.annotation.label_studio import ( LabelStudioAnnotationView, - LabelStudioManager, ) -from kazu.data import ENTITY_OUTSIDE_SYMBOL, Document, Entity, Section +from kazu.data import ENTITY_OUTSIDE_SYMBOL from kazu.training.config import TrainingConfig +from kazu.training.modelling_utils import ( + create_wrapper, + doc_yielder, + get_label_list, + test_doc_yielder, +) from kazu.training.train_multilabel_ner import ( KazuMultiHotNerMultiLabelTrainingDataset, - LSManagerViewWrapper, Trainer, ) from kazu.utils.constants import HYDRA_VERSION_BASE -from kazu.utils.utils import PathLike - - -def doc_yielder(path: PathLike) -> Iterable[Document]: - for file in Path(path).iterdir(): - with file.open(mode="r") as f: - try: - yield Document.from_json(f.read()) - except Exception as e: - print(f"failed to read: {file}, {e}") - - -def test_doc_yielder() -> Iterable[Document]: - section = Section(text="abracodabravir detameth targets BEHATHT.", name="test1") - section.entities.append( - Entity.load_contiguous_entity( - start=0, end=23, match="abracodabravir detameth", entity_class="drug", namespace="test" - ) - ) - section.entities.append( - Entity.load_contiguous_entity( - start=15, end=23, match="detameth", entity_class="salt", namespace="test" - ) - ) - section.entities.append( - Entity.load_contiguous_entity( - start=32, end=39, match="BEHATHT", entity_class="gene", namespace="test" - ) - ) - doc = Document(sections=[section]) - yield doc def create_view_for_labels( @@ -129,28 +101,6 @@ def run(cfg: DictConfig) -> None: trainer.train_model() -def create_wrapper(cfg: DictConfig, label_list: list[str]) -> Optional[LSManagerViewWrapper]: - if cfg.get("label_studio_manager") and cfg.get("css_colors"): - ls_manager: LabelStudioManager = instantiate(cfg.label_studio_manager) - css_colors = cfg.css_colors - label_to_color = {} - for i, label in enumerate(label_list): - label_to_color[label] = css_colors[i] - view = LabelStudioAnnotationView(ner_labels=label_to_color) - return LSManagerViewWrapper(view, ls_manager) - return None - - -def get_label_list(path: PathLike) -> list[str]: - label_list = set() - for doc in doc_yielder(path): - for entity in doc.get_entities(): - label_list.add(entity.entity_class) - label_list.add(ENTITY_OUTSIDE_SYMBOL) - # needs deterministic order for consistency - return sorted(list(label_list)) - - if __name__ == "__main__": freeze_support() run()