Skip to content

Commit

Permalink
refactor: move functions to a reusable module
Browse files Browse the repository at this point in the history
  • Loading branch information
paluchasz committed Nov 25, 2024
1 parent 5734563 commit 74cd82c
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 69 deletions.
11 changes: 6 additions & 5 deletions kazu/training/evaluate_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@
)
from kazu.steps.ner.tokenized_word_processor import TokenizedWordProcessor
from kazu.training.config import PredictionConfig
from kazu.training.modelling_utils import (
create_wrapper,
doc_yielder,
get_label_list_from_model,
)
from kazu.training.train_multilabel_ner import (
_select_keys_to_use,
calculate_metrics,
move_entities_to_metadata,
)
from kazu.training.train_script import create_wrapper, doc_yielder
from kazu.utils.constants import HYDRA_VERSION_BASE


Expand All @@ -36,10 +40,7 @@
def main(cfg: DictConfig) -> None:
prediction_config: PredictionConfig = instantiate(cfg.prediction_config)

with open(Path(prediction_config.path) / "config.json", "r") as file:
model_config = json.load(file)
id2label = {int(idx): label for idx, label in model_config["id2label"].items()}
label_list = [label for _, label in sorted(id2label.items())]
label_list = get_label_list_from_model(Path(prediction_config.path) / "config.json")
print(f"There are {len(label_list)} labels.")

step = TransformersModelForTokenClassificationNerStep(
Expand Down
76 changes: 76 additions & 0 deletions kazu/training/modelling_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import json
from pathlib import Path
from typing import Iterable, Optional

from hydra.utils import instantiate
from omegaconf import DictConfig

from kazu.annotation.label_studio import (
LabelStudioAnnotationView,
LabelStudioManager,
)
from kazu.data import ENTITY_OUTSIDE_SYMBOL, Document, Entity, Section
from kazu.training.train_multilabel_ner import (
LSManagerViewWrapper,
)
from kazu.utils.utils import PathLike


def doc_yielder(path: PathLike) -> Iterable[Document]:
    """Yield a Document for every entry in ``path`` that can be loaded.

    Each entry returned by ``Path(path).iterdir()`` is opened and parsed with
    ``Document.from_json``. This is best-effort: an entry that cannot be
    opened, read or parsed is reported with ``print`` and skipped, so a single
    bad entry (for example a stray subdirectory) does not abort the run.

    :param path: directory holding one JSON-serialised Document per file.
    :return: an iterable over the successfully loaded documents.
    """
    for file in Path(path).iterdir():
        try:
            # Opening is inside the try so that unreadable entries are
            # reported like any other load failure instead of raising.
            with file.open(mode="r") as f:
                doc = Document.from_json(f.read())
        except Exception as e:
            print(f"failed to read: {file}, {e}")
        else:
            # Yield outside the try/with: the file is already closed, and
            # exceptions thrown into the generator are not swallowed here.
            yield doc


def test_doc_yielder() -> Iterable[Document]:
    """Yield one small hand-built Document for quick test runs.

    The document has a single section containing three entities in the
    ``test`` namespace: a drug span, a salt span lying inside the drug span,
    and a gene span.
    """
    section = Section(text="abracodabravir detameth targets BEHATHT.", name="test1")
    # (start, end, matched text, entity class) - offsets index into section.text
    entity_specs = (
        (0, 23, "abracodabravir detameth", "drug"),
        (15, 23, "detameth", "salt"),
        (32, 39, "BEHATHT", "gene"),
    )
    for start, end, match, entity_class in entity_specs:
        section.entities.append(
            Entity.load_contiguous_entity(
                start=start,
                end=end,
                match=match,
                entity_class=entity_class,
                namespace="test",
            )
        )
    yield Document(sections=[section])


def get_label_list(path: PathLike) -> list[str]:
    """Return the sorted, de-duplicated entity classes found under ``path``.

    Every document produced by ``doc_yielder(path)`` is scanned and the
    ``entity_class`` of each of its entities is collected.
    ``ENTITY_OUTSIDE_SYMBOL`` is always included in the result.

    :param path: directory of serialised documents, as accepted by ``doc_yielder``.
    :return: the distinct labels in sorted order.
    """
    distinct_labels = {ENTITY_OUTSIDE_SYMBOL}
    for doc in doc_yielder(path):
        distinct_labels.update(entity.entity_class for entity in doc.get_entities())
    # sorted so the label order is deterministic across runs
    return sorted(distinct_labels)


def get_label_list_from_model(model_config_path: PathLike) -> list[str]:
    """Read the label list from a model's ``config.json``.

    The file must be a JSON object with an ``id2label`` mapping whose keys
    are integer ids serialised as strings.

    :param model_config_path: path to the model config JSON file.
    :return: the labels ordered by their numeric id.
    :raises KeyError: if the config has no ``id2label`` entry.
    """
    # JSON is UTF-8 by specification; an explicit encoding stops the result
    # depending on the platform's locale default.
    with open(model_config_path, "r", encoding="utf-8") as file:
        model_config = json.load(file)
    # Keys arrive as strings; convert so that e.g. "10" sorts after "2".
    id2label = {int(idx): label for idx, label in model_config["id2label"].items()}
    return [label for _, label in sorted(id2label.items())]


def create_wrapper(cfg: DictConfig, label_list: list[str]) -> Optional[LSManagerViewWrapper]:
    """Build a Label Studio manager/view wrapper if the config asks for one.

    A wrapper is created only when ``cfg`` has truthy ``label_studio_manager``
    and ``css_colors`` entries; otherwise ``None`` is returned.

    :param cfg: config; ``label_studio_manager`` is instantiated via hydra and
        ``css_colors`` is indexed once per label.
    :param label_list: labels to display; the label at position ``i`` is given
        ``css_colors[i]``, so ``css_colors`` needs at least as many entries.
    :return: the wrapper, or ``None`` when Label Studio is not configured.
    """
    if not (cfg.get("label_studio_manager") and cfg.get("css_colors")):
        return None

    manager: LabelStudioManager = instantiate(cfg.label_studio_manager)
    palette = cfg.css_colors
    color_by_label = {label: palette[idx] for idx, label in enumerate(label_list)}
    annotation_view = LabelStudioAnnotationView(ner_labels=color_by_label)
    return LSManagerViewWrapper(annotation_view, manager)
10 changes: 3 additions & 7 deletions kazu/training/predict_script.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Use this script to test the model with custom text inputs and visualize the
predictions in Label Studio."""

import json
from pathlib import Path

import hydra
Expand All @@ -15,8 +14,8 @@
)
from kazu.steps.ner.tokenized_word_processor import TokenizedWordProcessor
from kazu.training.config import PredictionConfig
from kazu.training.modelling_utils import create_wrapper, get_label_list_from_model
from kazu.training.train_multilabel_ner import _select_keys_to_use
from kazu.training.train_script import create_wrapper
from kazu.utils.constants import HYDRA_VERSION_BASE


Expand All @@ -30,10 +29,7 @@
def main(cfg: DictConfig) -> None:
prediction_config: PredictionConfig = instantiate(cfg.prediction_config)

with open(Path(prediction_config.path) / "config.json", "r") as file:
model_config = json.load(file)
id2label = {int(idx): label for idx, label in model_config["id2label"].items()}
label_list = [label for _, label in sorted(id2label.items())]
label_list = get_label_list_from_model(Path(prediction_config.path) / "config.json")
print(f"There are {len(label_list)} labels.")

step = TransformersModelForTokenClassificationNerStep(
Expand All @@ -51,7 +47,7 @@ def main(cfg: DictConfig) -> None:

manager = create_wrapper(cfg, label_list)
if manager is not None:
manager.update(documents, 0, has_gs=False)
manager.update(documents, "custom_predictions", has_gs=False)


if __name__ == "__main__":
Expand Down
64 changes: 7 additions & 57 deletions kazu/training/train_script.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
from multiprocessing import freeze_support
from pathlib import Path
from typing import Iterable, Optional

import hydra
from hydra.utils import instantiate
Expand All @@ -11,47 +10,20 @@

from kazu.annotation.label_studio import (
LabelStudioAnnotationView,
LabelStudioManager,
)
from kazu.data import ENTITY_OUTSIDE_SYMBOL, Document, Entity, Section
from kazu.data import ENTITY_OUTSIDE_SYMBOL
from kazu.training.config import TrainingConfig
from kazu.training.modelling_utils import (
create_wrapper,
doc_yielder,
get_label_list,
test_doc_yielder,
)
from kazu.training.train_multilabel_ner import (
KazuMultiHotNerMultiLabelTrainingDataset,
LSManagerViewWrapper,
Trainer,
)
from kazu.utils.constants import HYDRA_VERSION_BASE
from kazu.utils.utils import PathLike


def doc_yielder(path: PathLike) -> Iterable[Document]:
    """Yield a Document for every entry in ``path`` that can be loaded.

    Each entry returned by ``Path(path).iterdir()`` is opened and parsed with
    ``Document.from_json``. This is best-effort: an entry that cannot be
    opened, read or parsed is reported with ``print`` and skipped, so a single
    bad entry (for example a stray subdirectory) does not abort the run.

    :param path: directory holding one JSON-serialised Document per file.
    :return: an iterable over the successfully loaded documents.
    """
    for file in Path(path).iterdir():
        try:
            # Opening is inside the try so that unreadable entries are
            # reported like any other load failure instead of raising.
            with file.open(mode="r") as f:
                doc = Document.from_json(f.read())
        except Exception as e:
            print(f"failed to read: {file}, {e}")
        else:
            # Yield outside the try/with: the file is already closed, and
            # exceptions thrown into the generator are not swallowed here.
            yield doc


def test_doc_yielder() -> Iterable[Document]:
    """Yield one small hand-built Document for quick test runs.

    The document has a single section containing three entities in the
    ``test`` namespace: a drug span, a salt span lying inside the drug span,
    and a gene span.
    """
    section = Section(text="abracodabravir detameth targets BEHATHT.", name="test1")
    # (start, end, matched text, entity class) - offsets index into section.text
    entity_specs = (
        (0, 23, "abracodabravir detameth", "drug"),
        (15, 23, "detameth", "salt"),
        (32, 39, "BEHATHT", "gene"),
    )
    for start, end, match, entity_class in entity_specs:
        section.entities.append(
            Entity.load_contiguous_entity(
                start=start,
                end=end,
                match=match,
                entity_class=entity_class,
                namespace="test",
            )
        )
    yield Document(sections=[section])


def create_view_for_labels(
Expand Down Expand Up @@ -129,28 +101,6 @@ def run(cfg: DictConfig) -> None:
trainer.train_model()


def create_wrapper(cfg: DictConfig, label_list: list[str]) -> Optional[LSManagerViewWrapper]:
    """Build a Label Studio manager/view wrapper if the config asks for one.

    A wrapper is created only when ``cfg`` has truthy ``label_studio_manager``
    and ``css_colors`` entries; otherwise ``None`` is returned.

    :param cfg: config; ``label_studio_manager`` is instantiated via hydra and
        ``css_colors`` is indexed once per label.
    :param label_list: labels to display; the label at position ``i`` is given
        ``css_colors[i]``, so ``css_colors`` needs at least as many entries.
    :return: the wrapper, or ``None`` when Label Studio is not configured.
    """
    if not (cfg.get("label_studio_manager") and cfg.get("css_colors")):
        return None

    manager: LabelStudioManager = instantiate(cfg.label_studio_manager)
    palette = cfg.css_colors
    color_by_label = {label: palette[idx] for idx, label in enumerate(label_list)}
    annotation_view = LabelStudioAnnotationView(ner_labels=color_by_label)
    return LSManagerViewWrapper(annotation_view, manager)


def get_label_list(path: PathLike) -> list[str]:
    """Return the sorted, de-duplicated entity classes found under ``path``.

    Every document produced by ``doc_yielder(path)`` is scanned and the
    ``entity_class`` of each of its entities is collected.
    ``ENTITY_OUTSIDE_SYMBOL`` is always included in the result.

    :param path: directory of serialised documents, as accepted by ``doc_yielder``.
    :return: the distinct labels in sorted order.
    """
    distinct_labels = {ENTITY_OUTSIDE_SYMBOL}
    for doc in doc_yielder(path):
        distinct_labels.update(entity.entity_class for entity in doc.get_entities())
    # sorted so the label order is deterministic across runs
    return sorted(distinct_labels)


if __name__ == "__main__":
freeze_support()
run()

0 comments on commit 74cd82c

Please sign in to comment.