Skip to content

Commit

Permalink
CU-8696v2j42: Add test to make sure per-cui counts are kept when crea…
Browse files Browse the repository at this point in the history
…ting folds (#508)

* CU-8696v2j42: Add test to make sure per-cui counts are kept when creating folds

* CU-8696v2j42: Fix per annotation fold creation
  • Loading branch information
mart-r authored Dec 9, 2024
1 parent bb41955 commit 00c0dd0
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
12 changes: 7 additions & 5 deletions medcat/stats/kfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from enum import Enum, auto
from copy import deepcopy
from pydantic import BaseModel
from itertools import islice

import numpy as np

Expand Down Expand Up @@ -205,19 +206,20 @@ def _add_target_ann(self, project: MedCATTrainerExportProject,
cur_doc: MedCATTrainerExportDocument = self._find_or_add_doc(project, orig_doc)
cur_doc['annotations'].append(ann)

def _targets(self) -> Iterable[Tuple[MedCATTrainerExportProjectInfo,
MedCATTrainerExportDocument,
MedCATTrainerExportAnnotation]]:
return iter_anns(self.mct_export)
def _targets(self, start_at: int) -> Iterable[Tuple[MedCATTrainerExportProjectInfo,
MedCATTrainerExportDocument,
MedCATTrainerExportAnnotation]]:
return islice(iter_anns(self.mct_export), start_at, None)

def _create_fold(self, fold_nr: int) -> MedCATTrainerExport:
per_fold = self.per_fold[fold_nr]
already_used = sum(self.per_fold[fn] for fn in range(fold_nr))
cur_fold: MedCATTrainerExport = {
'projects': []
}
cur_project: Optional[MedCATTrainerExportProject] = None
included = 0
for target in self._targets():
for target in self._targets(already_used):
proj_info, cur_doc, cur_ann = target
proj_name = proj_info[0]
if not cur_project or cur_project['name'] != proj_name:
Expand Down
15 changes: 15 additions & 0 deletions tests/stats/test_kfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
from typing import Dict, Union, Optional
from copy import deepcopy
from collections import Counter

from medcat.stats import kfold
from medcat.cat import CAT
Expand Down Expand Up @@ -80,6 +81,20 @@ def test_folds_keep_all_anns(self):
count_all_once = kfold.count_all_annotations(self.mct_export)
self.assertEqual(total_anns, count_all_once)

def count_cuis(self, export: MCTExportTests) -> Counter:
cntr = Counter()
for _, _, ann in kfold.iter_anns(export):
cui = ann["cui"]
cntr[cui] += 1
return cntr

def test_folds_keep_ann_targets(self):
orig_cntr = self.count_cuis(self.mct_export)
fold_counter = Counter()
for fold in self.folds:
fold_counter += self.count_cuis(fold)
self.assertEqual(orig_cntr, fold_counter)

def test_1fold_same_as_orig(self):
folds = kfold.get_fold_creator(self.mct_export, 1, split_type=self.SPLIT_TYPE).create_folds()
self.assertEqual(len(folds), 1)
Expand Down

0 comments on commit 00c0dd0

Please sign in to comment.