diff --git a/ocr-correction-viklofg-sweocr/src/sbx_ocr_correction_viklofg_sweocr/__init__.py b/ocr-correction-viklofg-sweocr/src/sbx_ocr_correction_viklofg_sweocr/__init__.py
index 4e7b87c..e63b139 100644
--- a/ocr-correction-viklofg-sweocr/src/sbx_ocr_correction_viklofg_sweocr/__init__.py
+++ b/ocr-correction-viklofg-sweocr/src/sbx_ocr_correction_viklofg_sweocr/__init__.py
@@ -1,7 +1,9 @@
+"""Sparv plugin to annotate tokens with correction of OCR errors."""
+
 from sbx_ocr_correction_viklofg_sweocr.annotations import annotate_ocr_correction
 
 __all__ = ["annotate_ocr_correction"]
 
-__description__ = "Annotate words with corrections of OCR-errors."
+__description__ = "Annotate tokens with corrections of OCR errors."
 __version__ = "0.3.0"
 
diff --git a/ocr-correction-viklofg-sweocr/src/sbx_ocr_correction_viklofg_sweocr/annotations.py b/ocr-correction-viklofg-sweocr/src/sbx_ocr_correction_viklofg_sweocr/annotations.py
index 304a13b..4b016eb 100644
--- a/ocr-correction-viklofg-sweocr/src/sbx_ocr_correction_viklofg_sweocr/annotations.py
+++ b/ocr-correction-viklofg-sweocr/src/sbx_ocr_correction_viklofg_sweocr/annotations.py
@@ -1,4 +1,6 @@
-from typing import List, Optional, Tuple
+"""Annotations for Sparv."""
+
+from typing import Optional
 
 from sparv import api as sparv_api  # type: ignore [import-untyped]
 from sparv.api import Annotation, Output, annotator  # type: ignore [import-untyped]
@@ -10,27 +12,20 @@
 
 
 @annotator("OCR corrections as annotations", language=["swe"])
 def annotate_ocr_correction(
-    out_ocr: Output = Output(
-        "sbx_ocr_correction_viklofg_sweocr.sbx-ocr-correction", cls="ocr_correction"
-    ),
+    out_ocr: Output = Output("sbx_ocr_correction_viklofg_sweocr.sbx-ocr-correction", cls="ocr_correction"),
     out_ocr_corr: Output = Output(
         "sbx_ocr_correction_viklofg_sweocr.sbx-ocr-correction:sbx_ocr_correction_viklofg_sweocr.ocr-correction--viklofg-sweocr",
         cls="ocr_correction:correction",
     ),
-    # out_ocr_correction: Output = Output(
-    #     ":sbx_ocr_correction_viklofg_sweocr.ocr-correction--viklofg-sweocr",
-    #     cls="sbx_ocr_correction_viklofg_sweocr",
-    #     description="OCR Corrections from viklfog/swedish-ocr (format: '|:|...|)",  # noqa: E501
-    # ),
     word: Annotation = Annotation("<token:word>"),
     sentence: Annotation = Annotation("<sentence>"),
     token: Annotation = Annotation("<token>"),
 ) -> None:
+    """Sparv annotator to compute OCR corrections as annotations."""
     ocr_corrector = OcrCorrector.default()
 
     sentences, _orphans = sentence.get_children(word)
     token_word = list(word.read())
-    # out_ocr_correction_annotation = word.create_empty_attribute()
 
     ocr_corrections = []
 
@@ -40,22 +35,18 @@ def annotate_ocr_correction(
         sent = [token_word[token_index] for token_index in sent_idx]
         ocr_corrections.append(ocr_corrector.calculate_corrections(sent))
-        # for i, ocr_correction in enumerate(ocr_corrections, start=sent[0]):
-        #     out_ocr_correction_annotation[i] = ocr_correction
 
-    # logger.info("writing annotations")
-    # out_ocr.write(ocr_spans)
-    # out_ocr_corr.write(ocr_corr_ann)
     parse_ocr_corrections(sentences, token, ocr_corrections, out_ocr, out_ocr_corr)
 
 
 def parse_ocr_corrections(
-    sentences: List,
+    sentences: list,
     token: Annotation,
-    ocr_corrections: List[List[Tuple[Tuple[int, int], Optional[str]]]],
+    ocr_corrections: list[list[tuple[tuple[int, int], Optional[str]]]],
     out_ocr: Output,
     out_ocr_corr: Output,
 ) -> None:
+    """Parse OCR corrections and write output."""
     ocr_spans = []
     ocr_corr_ann = []
 
diff --git a/ocr-correction-viklofg-sweocr/src/sbx_ocr_correction_viklofg_sweocr/ocr_corrector.py b/ocr-correction-viklofg-sweocr/src/sbx_ocr_correction_viklofg_sweocr/ocr_corrector.py
index 0347442..9ff28ba 100644
--- a/ocr-correction-viklofg-sweocr/src/sbx_ocr_correction_viklofg_sweocr/ocr_corrector.py
+++ b/ocr-correction-viklofg-sweocr/src/sbx_ocr_correction_viklofg_sweocr/ocr_corrector.py
@@ -1,5 +1,7 @@
+"""OCR corrector."""
+
 import re
-from typing import List, Optional, Tuple
+from typing import Any, Optional
 
 from parallel_corpus import graph
 from parallel_corpus.token import Token
@@ -12,6 +14,7 @@
 
 
 def bytes_length(s: str) -> int:
+    """Compute the length in bytes of a str."""
     return len(s.encode("utf-8"))
 
 
@@ -26,34 +29,31 @@ def bytes_length(s: str) -> int:
 
 
 class OcrCorrector:
+    """OCR Corrector."""
+
     TEXT_LIMIT: int = 127
 
-    def __init__(self, *, tokenizer, model) -> None:
+    def __init__(self, *, tokenizer: Any, model: Any) -> None:
+        """Construct an OcrCorrector."""
         self.tokenizer = tokenizer
         self.model = model
-        self.pipeline = pipeline(
-            "text2text-generation", model=model, tokenizer=tokenizer
-        )
+        self.pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
 
     @classmethod
     def default(cls) -> "OcrCorrector":
-        tokenizer = AutoTokenizer.from_pretrained(
-            TOKENIZER_NAME, revision=TOKENIZER_REVISION
-        )
-        model = T5ForConditionalGeneration.from_pretrained(
-            MODEL_NAME, revision=MODEL_REVISION
-        )
+        """Create a default OcrCorrector."""
+        tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, revision=TOKENIZER_REVISION)
+        model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, revision=MODEL_REVISION)
         return cls(model=model, tokenizer=tokenizer)
 
-    def calculate_corrections(
-        self, text: List[str]
-    ) -> List[Tuple[Tuple[int, int], Optional[str]]]:
+    def calculate_corrections(self, text: list[str]) -> list[tuple[tuple[int, int], Optional[str]]]:
+        """Calculate OCR corrections for a given text."""
         logger.debug("Analyzing '%s'", text)
-        parts: List[str] = []
-        curr_part: List[str] = []
+        parts: list[str] = []
+        curr_part: list[str] = []
         curr_len = 0
-        ocr_corrections: List[Tuple[Tuple[int, int], Optional[str]]] = []
+        ocr_corrections: list[tuple[tuple[int, int], Optional[str]]] = []
         for word in text:
             len_word = bytes_length(word)
             if (curr_len + len_word + 1) > self.TEXT_LIMIT:
@@ -69,16 +69,14 @@ def calculate_corrections(
             suggested_text = self.pipeline(part)[0]["generated_text"]
             suggested_text = PUNCTUATION.sub(r" \g<0>", suggested_text)
             graph_aligned = graph.init_with_source_and_target(part, suggested_text)
-            span_ann, curr_start = align_and_diff(graph_aligned, curr_start=curr_start)
+            span_ann, curr_start = _align_and_diff(graph_aligned, curr_start=curr_start)
             ocr_corrections.extend(span_ann)
 
         logger.debug("Finished analyzing. ocr_corrections=%s", ocr_corrections)
         return ocr_corrections
 
 
-def align_and_diff(
-    g: graph.Graph, *, curr_start: int
-) -> Tuple[List[Tuple[Tuple[int, int], Optional[str]]], int]:
+def _align_and_diff(g: graph.Graph, *, curr_start: int) -> tuple[list[tuple[tuple[int, int], Optional[str]]], int]:
     corrections = []
     edge_map = graph.edge_map(g)
 
@@ -96,12 +94,8 @@ def align_and_diff(
 
         logger.debug("processing s_token=%s", s_token)
         if len(source_ids) == len(target_ids):
-            source_text = "".join(
-                lookup_text(g.source, s_id) for s_id in source_ids
-            ).strip()
-            target_text = "".join(
-                lookup_text(g.target, s_id) for s_id in target_ids
-            ).strip()
+            source_text = "".join(lookup_text(g.source, s_id) for s_id in source_ids).strip()
+            target_text = "".join(lookup_text(g.target, s_id) for s_id in target_ids).strip()
             start = curr_start
             curr_start += 1
             corrections.append(
@@ -112,9 +106,7 @@ def align_and_diff(
                 )
             )
         elif len(source_ids) == 1:
-            target_texts = " ".join(
-                lookup_text(g.target, id_).strip() for id_ in target_ids
-            )
+            target_texts = " ".join(lookup_text(g.target, id_).strip() for id_ in target_ids)
             source_text = s_token.text.strip()
             start = curr_start
             curr_start += 1
@@ -132,14 +124,13 @@ def align_and_diff(
             corrections.append(((start, curr_start), target_text))
         else:
             # TODO Handle this correct (https://github.com/spraakbanken/sparv-sbx-ocr-correction/issues/50)
-            raise NotImplementedError(
-                f"Handle several sources, {source_ids=} {target_ids=} {g.source=} {g.target=}"  # noqa: E501
-            )
+            raise NotImplementedError(f"Handle several sources, {source_ids=} {target_ids=} {g.source=} {g.target=}")
 
     return corrections, curr_start
 
 
-def lookup_text(tokens: List[Token], id_: str) -> str:
+def lookup_text(tokens: list[Token], id_: str) -> str:
+    """Look up the text of the token with id `id_`."""
     for token in tokens:
         if token.id == id_:
             return token.text
diff --git a/ocr-correction-viklofg-sweocr/tests/test_annotations.py b/ocr-correction-viklofg-sweocr/tests/test_annotations.py
index 50ca4ea..b498252 100644
--- a/ocr-correction-viklofg-sweocr/tests/test_annotations.py
+++ b/ocr-correction-viklofg-sweocr/tests/test_annotations.py
@@ -2,16 +2,14 @@
 from sparv_pipeline_testing import MemoryOutput, MockAnnotation
 
 
-def test_annotate_ocr_correction(snapshot) -> None:
+def test_annotate_ocr_correction(snapshot) -> None:  # noqa: ANN001
     output_ocr: MemoryOutput = MemoryOutput()
     output_ocr_corr: MemoryOutput = MemoryOutput()
     # "Jonath an saknades ."
# "12345678901234567890" # " 1 2" word = MockAnnotation(name="", values=["Jonath", "an", "saknades", "."]) - sentence = MockAnnotation( - name="", children={"": [[0, 1, 2, 3]]} - ) + sentence = MockAnnotation(name="", children={"": [[0, 1, 2, 3]]}) token = MockAnnotation(name="", spans=[(0, 6), (7, 9), (10, 18), (19, 20)]) annotate_ocr_correction(output_ocr, output_ocr_corr, word, sentence, token) diff --git a/ocr-correction-viklofg-sweocr/tests/test_ocr_suggestor.py b/ocr-correction-viklofg-sweocr/tests/test_ocr_suggestor.py index fa90704..b660f07 100644 --- a/ocr-correction-viklofg-sweocr/tests/test_ocr_suggestor.py +++ b/ocr-correction-viklofg-sweocr/tests/test_ocr_suggestor.py @@ -1,7 +1,7 @@ from sbx_ocr_correction_viklofg_sweocr.ocr_corrector import OcrCorrector -def test_short_text(ocr_corrector: OcrCorrector, snapshot): +def test_short_text(ocr_corrector: OcrCorrector, snapshot) -> None: # noqa: ANN001 text = [ "Den", "i", @@ -21,7 +21,7 @@ def test_short_text(ocr_corrector: OcrCorrector, snapshot): assert actual == snapshot -def test_long_text(ocr_corrector: OcrCorrector, snapshot): +def test_long_text(ocr_corrector: OcrCorrector, snapshot) -> None: # noqa: ANN001 text1 = [ "Förvaltningen", "af", @@ -62,7 +62,7 @@ def test_long_text(ocr_corrector: OcrCorrector, snapshot): assert actual == snapshot -def test_issue_40(ocr_corrector: OcrCorrector, snapshot) -> None: +def test_issue_40(ocr_corrector: OcrCorrector, snapshot) -> None: # noqa: ANN001 example = [ "Jonathan", "saknades",