Skip to content

Commit

Permalink
Merge pull request #594 from mittagessen/feature/test_wer
Browse files: browse the repository at this point in the history
Add WER calculation to `ketos test` report
  • Loading branch information
mittagessen authored Apr 22, 2024
2 parents 3982aee + a3e22c1 commit fd6a9d0
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 59 deletions.
12 changes: 1 addition & 11 deletions kraken/ketos/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,6 @@
@click.option('--threads', show_default=True, default=1, type=click.IntRange(1), help='Maximum size of OpenMP/BLAS thread pool.')
@click.option('--load-hyper-parameters/--no-load-hyper-parameters', show_default=True, default=False,
help='When loading an existing model, retrieve hyperparameters from the model')
@click.option('--repolygonize/--no-repolygonize', show_default=True,
default=False, help='Repolygonizes line data in ALTO/PageXML '
'files. This ensures that the trained model is compatible with the '
'segmenter in kraken even if the original image files either do '
'not contain anything but transcriptions and baseline information '
'or the polygon data was created using a different method. Will '
'be ignored in `path` mode. Note that this option will be slow '
'and will not scale input images to the same size as the segmenter '
'does.')
@click.option('--force-binarization/--no-binarization', show_default=True,
default=False, help='Forces input images to be binary, otherwise '
'the appropriate color format will be auto-determined through the '
Expand Down Expand Up @@ -188,7 +179,7 @@ def pretrain(ctx, batch_size, pad, output, spec, load, freq, quit, epochs,
min_epochs, lag, min_delta, device, precision, optimizer, lrate, momentum,
weight_decay, warmup, schedule, gamma, step_size, sched_patience,
cos_max, cos_min_lr, partition, fixed_splits, training_files,
evaluation_files, workers, threads, load_hyper_parameters, repolygonize,
evaluation_files, workers, threads, load_hyper_parameters,
force_binarization, format_type, augment,
mask_probability, mask_width, num_negatives, logit_temp,
ground_truth, legacy_polygons):
Expand Down Expand Up @@ -278,7 +269,6 @@ def pretrain(ctx, batch_size, pad, output, spec, load, freq, quit, epochs,
height=model.height,
width=model.width,
channels=model.channels,
repolygonize=repolygonize,
force_binarization=force_binarization,
format_type=format_type,
legacy_polygons=legacy_polygons,)
Expand Down
61 changes: 29 additions & 32 deletions kraken/ketos/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,6 @@
@click.option('--threads', show_default=True, default=1, type=click.IntRange(1), help='Maximum size of OpenMP/BLAS thread pool.')
@click.option('--load-hyper-parameters/--no-load-hyper-parameters', show_default=True, default=False,
help='When loading an existing model, retrieve hyperparameters from the model')
@click.option('--repolygonize/--no-repolygonize', show_default=True,
default=False, help='Repolygonizes line data in ALTO/PageXML '
'files. This ensures that the trained model is compatible with the '
'segmenter in kraken even if the original image files either do '
'not contain anything but transcriptions and baseline information '
'or the polygon data was created using a different method. Will '
'be ignored in `path` mode. Note that this option will be slow '
'and will not scale input images to the same size as the segmenter '
'does.')
@click.option('--force-binarization/--no-binarization', show_default=True,
default=False, help='Forces input images to be binary, otherwise '
'the appropriate color format will be auto-determined through the '
Expand Down Expand Up @@ -203,7 +194,7 @@ def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs,
step_size, sched_patience, cos_max, cos_min_lr, partition,
fixed_splits, normalization, normalize_whitespace, codec, resize,
reorder, base_dir, training_files, evaluation_files, workers,
threads, load_hyper_parameters, repolygonize, force_binarization,
threads, load_hyper_parameters, force_binarization,
format_type, augment, pl_logger, log_dir, ground_truth,
legacy_polygons):
"""
Expand Down Expand Up @@ -305,7 +296,6 @@ def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs,
binary_dataset_split=fixed_splits,
num_workers=workers,
load_hyper_parameters=load_hyper_parameters,
repolygonize=repolygonize,
force_binarization=force_binarization,
format_type=format_type,
codec=codec,
Expand Down Expand Up @@ -385,15 +375,6 @@ def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs,
default=None, help='Ground truth normalization')
@click.option('-n', '--normalize-whitespace/--no-normalize-whitespace',
show_default=True, default=True, help='Normalizes unicode whitespace')
@click.option('--repolygonize/--no-repolygonize', show_default=True,
default=False, help='Repolygonizes line data in ALTO/PageXML '
'files. This ensures that the trained model is compatible with the '
'segmenter in kraken even if the original image files either do '
'not contain anything but transcriptions and baseline information '
'or the polygon data was created using a different method. Will '
'be ignored in `path` mode. Note, that this option will be slow '
'and will not scale input images to the same size as the segmenter '
'does.')
@click.option('--force-binarization/--no-binarization', show_default=True,
default=False, help='Forces input images to be binary, otherwise '
'the appropriate color format will be auto-determined through the '
Expand All @@ -411,7 +392,7 @@ def train(ctx, batch_size, pad, output, spec, append, load, freq, quit, epochs,
@click.option('--no-legacy-polygons', show_default=True, default=False, is_flag=True, help='Force disable the legacy polygon extractor.')
def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
threads, reorder, base_dir, normalization, normalize_whitespace,
repolygonize, force_binarization, format_type, fixed_splits, test_set, no_legacy_polygons):
force_binarization, format_type, fixed_splits, test_set, no_legacy_polygons):
"""
Evaluate on a test set.
"""
Expand All @@ -421,6 +402,8 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
import numpy as np
from torch.utils.data import DataLoader

from torchmetrics.text import CharErrorRate, WordErrorRate

from kraken.lib import models, util
from kraken.lib.dataset import (ArrowIPCRecognitionDataset,
GroundTruthDataset, ImageInputTransforms,
Expand Down Expand Up @@ -475,24 +458,18 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
dataset_kwargs["split_filter"] = "test"

if format_type in ['xml', 'page', 'alto']:
if repolygonize:
message('Repolygonizing data')
test_set = [{'page': XMLPage(file, filetype=format_type).to_container()} for file in test_set]
valid_norm = False
DatasetClass = partial(PolygonGTDataset, legacy_polygons=legacy_polygons)
elif format_type == 'binary':
DatasetClass = ArrowIPCRecognitionDataset
if repolygonize:
logger.warning('Repolygonization enabled in `binary` mode. Will be ignored.')
test_set = [{'file': file} for file in test_set]
valid_norm = False
else:
DatasetClass = GroundTruthDataset
if force_binarization:
logger.warning('Forced binarization enabled in `path` mode. Will be ignored.')
force_binarization = False
if repolygonize:
logger.warning('Repolygonization enabled in `path` mode. Will be ignored.')
test_set = [{'line': util.parse_gt_path(img)} for img in test_set]
valid_norm = True

Expand All @@ -502,7 +479,8 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
if reorder and base_dir != 'auto':
reorder = base_dir

acc_list = []
cer_list = []
wer_list = []

with threadpool_limits(limits=threads):
for p, net in nn.items():
Expand Down Expand Up @@ -539,6 +517,9 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
pin_memory=pin_ds_mem,
collate_fn=collate_sequences)

test_cer = CharErrorRate()
test_wer = WordErrorRate()

with KrakenProgressBar() as progress:
batches = len(ds_loader)
pred_task = progress.add_task('Evaluating', total=batches, visible=True if not ctx.meta['verbose'] else False)
Expand All @@ -555,6 +536,9 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
algn_gt.extend(algn1)
algn_pred.extend(algn2)
error += c
test_cer.update(x, y)
test_wer.update(x, y)

except FileNotFoundError as e:
batches -= 1
progress.update(pred_task, total=batches)
Expand All @@ -565,10 +549,23 @@ def test(ctx, batch_size, model, evaluation_files, device, pad, workers,
logger.warning(str(e))
progress.update(pred_task, advance=1)

acc_list.append((chars - error) / chars)
cer_list.append(1.0 - test_cer.compute())
wer_list.append(1.0 - test_wer.compute())
confusions, scripts, ins, dels, subs = compute_confusions(algn_gt, algn_pred)
rep = render_report(p, chars, error, confusions, scripts, ins, dels, subs)
rep = render_report(p,
chars,
error,
cer_list[-1],
wer_list[-1],
confusions,
scripts,
ins,
dels,
subs)
logger.info(rep)
message(rep)
logger.info('Average accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(np.mean(acc_list) * 100, np.std(acc_list) * 100))
message('Average accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(np.mean(acc_list) * 100, np.std(acc_list) * 100))

logger.info('Average character accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(np.mean(cer_list) * 100, np.std(cer_list) * 100))
message('Average character accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(np.mean(cer_list) * 100, np.std(cer_list) * 100))
logger.info('Average word accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(np.mean(wer_list) * 100, np.std(wer_list) * 100))
message('Average word accuracy: {:0.2f}%, (stddev: {:0.2f})'.format(np.mean(wer_list) * 100, np.std(wer_list) * 100))
7 changes: 0 additions & 7 deletions kraken/lib/pretrain/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def __init__(self,
width: int = 0,
channels: int = 1,
num_workers: int = 1,
repolygonize: bool = False,
force_binarization: bool = False,
format_type: str = 'path',
pad: int = 16,
Expand Down Expand Up @@ -125,8 +124,6 @@ def __init__(self,
valid_norm = False
elif format_type == 'binary':
DatasetClass = ArrowIPCRecognitionDataset
if repolygonize:
logger.warning('Repolygonization enabled in `binary` mode. Will be ignored.')
valid_norm = False
logger.info(f'Got {len(training_data)} binary dataset files for training data')
training_data = [{'file': file} for file in training_data]
Expand All @@ -137,8 +134,6 @@ def __init__(self,
if force_binarization:
logger.warning('Forced binarization enabled in `path` mode. Will be ignored.')
force_binarization = False
if repolygonize:
logger.warning('Repolygonization enabled in `path` mode. Will be ignored.')
if binary_dataset_split:
logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. Will be ignored.')
binary_dataset_split = False
Expand All @@ -157,8 +152,6 @@ def __init__(self,
if force_binarization:
logger.warning('Forced binarization enabled with box lines. Will be ignored.')
force_binarization = False
if repolygonize:
logger.warning('Repolygonization enabled with box lines. Will be ignored.')
if binary_dataset_split:
logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. Will be ignored.')
binary_dataset_split = False
Expand Down
7 changes: 0 additions & 7 deletions kraken/lib/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,6 @@ def __init__(self,
binary_dataset_split: bool = False,
num_workers: int = 1,
load_hyper_parameters: bool = False,
repolygonize: bool = False,
force_binarization: bool = False,
format_type: Literal['path', 'alto', 'page', 'xml', 'binary'] = 'path',
codec: Optional[Dict] = None,
Expand Down Expand Up @@ -291,8 +290,6 @@ def __init__(self,
valid_norm = False
elif format_type == 'binary':
DatasetClass = ArrowIPCRecognitionDataset
if repolygonize:
logger.warning('Repolygonization enabled in `binary` mode. Will be ignored.')
valid_norm = False
logger.info(f'Got {len(training_data)} binary dataset files for training data')
training_data = [{'file': file} for file in training_data]
Expand All @@ -303,8 +300,6 @@ def __init__(self,
if force_binarization:
logger.warning('Forced binarization enabled in `path` mode. Will be ignored.')
force_binarization = False
if repolygonize:
logger.warning('Repolygonization enabled in `path` mode. Will be ignored.')
if binary_dataset_split:
logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. Will be ignored.')
binary_dataset_split = False
Expand All @@ -323,8 +318,6 @@ def __init__(self,
if force_binarization:
logger.warning('Forced binarization enabled with box lines. Will be ignored.')
force_binarization = False
if repolygonize:
logger.warning('Repolygonization enabled with box lines. Will be ignored.')
if binary_dataset_split:
logger.warning('Internal binary dataset splits are enabled but using non-binary dataset files. Will be ignored.')
binary_dataset_split = False
Expand Down
5 changes: 4 additions & 1 deletion kraken/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ def _load_template(name):
def render_report(model: str,
chars: int,
errors: int,
char_accuracy: float,
word_accuracy: float,
char_confusions: 'Counter',
scripts: 'Counter',
insertions: 'Counter',
Expand Down Expand Up @@ -275,7 +277,8 @@ def render_report(model: str,
report = {'model': model,
'chars': chars,
'errors': errors,
'accuracy': (chars-errors)/chars * 100,
'character_accuracy': char_accuracy * 100,
'word_accuracy': word_accuracy * 100,
'insertions': sum(insertions.values()),
'deletions': deletions,
'substitutions': sum(substitutions.values()),
Expand Down
3 changes: 2 additions & 1 deletion kraken/templates/report
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

{{ report.chars }} Characters
{{ report.errors }} Errors
{{ '%0.2f'| format(report.accuracy) }}% Accuracy
{{ '%0.2f'| format(report.character_accuracy) }}% Character Accuracy
{{ '%0.2f'| format(report.word_accuracy) }}% Word Accuracy

{{ report.insertions }} Insertions
{{ report.deletions }} Deletions
Expand Down

0 comments on commit fd6a9d0

Please sign in to comment.