Skip to content

Commit

Permalink
docs: add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
kod-kristoff committed Nov 18, 2024
1 parent 8436d1f commit d5fa2c3
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 59 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Sparv plugin to annotate tokens with correction of OCR errors."""

from sbx_ocr_correction_viklofg_sweocr.annotations import annotate_ocr_correction

__all__ = ["annotate_ocr_correction"]

__description__ = "Annotate words with corrections of OCR-errors."
__description__ = "Annotate tokens with corrections of OCR-errors."

__version__ = "0.3.0"
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import List, Optional, Tuple
"""Annotatinos for Sparv."""

from typing import Optional

from sparv import api as sparv_api # type: ignore [import-untyped]
from sparv.api import Annotation, Output, annotator # type: ignore [import-untyped]
Expand All @@ -10,27 +12,20 @@

@annotator("OCR corrections as annotations", language=["swe"])
def annotate_ocr_correction(
out_ocr: Output = Output(
"sbx_ocr_correction_viklofg_sweocr.sbx-ocr-correction", cls="ocr_correction"
),
out_ocr: Output = Output("sbx_ocr_correction_viklofg_sweocr.sbx-ocr-correction", cls="ocr_correction"),
out_ocr_corr: Output = Output(
"sbx_ocr_correction_viklofg_sweocr.sbx-ocr-correction:sbx_ocr_correction_viklofg_sweocr.ocr-correction--viklofg-sweocr",
cls="ocr_correction:correction",
),
# out_ocr_correction: Output = Output(
# "<token>:sbx_ocr_correction_viklofg_sweocr.ocr-correction--viklofg-sweocr",
# cls="sbx_ocr_correction_viklofg_sweocr",
# description="OCR Corrections from viklfog/swedish-ocr (format: '|<word>:<score>|...|)", # noqa: E501
# ),
word: Annotation = Annotation("<token:word>"),
sentence: Annotation = Annotation("<sentence>"),
token: Annotation = Annotation("<token>"),
) -> None:
"""Sparv annotator to compute OCR corrections as annotations."""
ocr_corrector = OcrCorrector.default()

sentences, _orphans = sentence.get_children(word)
token_word = list(word.read())
# out_ocr_correction_annotation = word.create_empty_attribute()

ocr_corrections = []

Expand All @@ -40,22 +35,18 @@ def annotate_ocr_correction(
sent = [token_word[token_index] for token_index in sent_idx]

ocr_corrections.append(ocr_corrector.calculate_corrections(sent))
# for i, ocr_correction in enumerate(ocr_corrections, start=sent[0]):
# out_ocr_correction_annotation[i] = ocr_correction

# logger.info("writing annotations")
# out_ocr.write(ocr_spans)
# out_ocr_corr.write(ocr_corr_ann)
parse_ocr_corrections(sentences, token, ocr_corrections, out_ocr, out_ocr_corr)


def parse_ocr_corrections(
sentences: List,
sentences: list,
token: Annotation,
ocr_corrections: List[List[Tuple[Tuple[int, int], Optional[str]]]],
ocr_corrections: list[list[tuple[tuple[int, int], Optional[str]]]],
out_ocr: Output,
out_ocr_corr: Output,
) -> None:
"""Parse OCR corrections and write output."""
ocr_spans = []
ocr_corr_ann = []

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""OCR corrector."""

import re
from typing import List, Optional, Tuple
from typing import Any, Optional

from parallel_corpus import graph
from parallel_corpus.token import Token
Expand All @@ -12,6 +14,7 @@


def bytes_length(s: str) -> int:
"""Compute the length in bytes of a str."""
return len(s.encode("utf-8"))


Expand All @@ -26,34 +29,31 @@ def bytes_length(s: str) -> int:


class OcrCorrector:
"""OCR Corrector."""

TEXT_LIMIT: int = 127

def __init__(self, *, tokenizer, model) -> None:
def __init__(self, *, tokenizer: Any, model: Any) -> None:
"""Construct an OcrCorrector."""
self.tokenizer = tokenizer
self.model = model
self.pipeline = pipeline(
"text2text-generation", model=model, tokenizer=tokenizer
)
self.pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)

@classmethod
def default(cls) -> "OcrCorrector":
tokenizer = AutoTokenizer.from_pretrained(
TOKENIZER_NAME, revision=TOKENIZER_REVISION
)
model = T5ForConditionalGeneration.from_pretrained(
MODEL_NAME, revision=MODEL_REVISION
)
"""Create a default OcrCorrector."""
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, revision=TOKENIZER_REVISION)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, revision=MODEL_REVISION)
return cls(model=model, tokenizer=tokenizer)

def calculate_corrections(
self, text: List[str]
) -> List[Tuple[Tuple[int, int], Optional[str]]]:
def calculate_corrections(self, text: list[str]) -> list[tuple[tuple[int, int], Optional[str]]]:
"""Calculate OCR corrections for a given text."""
logger.debug("Analyzing '%s'", text)

parts: List[str] = []
curr_part: List[str] = []
parts: list[str] = []
curr_part: list[str] = []
curr_len = 0
ocr_corrections: List[Tuple[Tuple[int, int], Optional[str]]] = []
ocr_corrections: list[tuple[tuple[int, int], Optional[str]]] = []
for word in text:
len_word = bytes_length(word)
if (curr_len + len_word + 1) > self.TEXT_LIMIT:
Expand All @@ -69,16 +69,14 @@ def calculate_corrections(
suggested_text = self.pipeline(part)[0]["generated_text"]
suggested_text = PUNCTUATION.sub(r" \g<0>", suggested_text)
graph_aligned = graph.init_with_source_and_target(part, suggested_text)
span_ann, curr_start = align_and_diff(graph_aligned, curr_start=curr_start)
span_ann, curr_start = _align_and_diff(graph_aligned, curr_start=curr_start)
ocr_corrections.extend(span_ann)

logger.debug("Finished analyzing. ocr_corrections=%s", ocr_corrections)
return ocr_corrections


def align_and_diff(
g: graph.Graph, *, curr_start: int
) -> Tuple[List[Tuple[Tuple[int, int], Optional[str]]], int]:
def _align_and_diff(g: graph.Graph, *, curr_start: int) -> tuple[list[tuple[tuple[int, int], Optional[str]]], int]:
corrections = []

edge_map = graph.edge_map(g)
Expand All @@ -96,12 +94,8 @@ def align_and_diff(
logger.debug("processing s_token=%s", s_token)

if len(source_ids) == len(target_ids):
source_text = "".join(
lookup_text(g.source, s_id) for s_id in source_ids
).strip()
target_text = "".join(
lookup_text(g.target, s_id) for s_id in target_ids
).strip()
source_text = "".join(lookup_text(g.source, s_id) for s_id in source_ids).strip()
target_text = "".join(lookup_text(g.target, s_id) for s_id in target_ids).strip()
start = curr_start
curr_start += 1
corrections.append(
Expand All @@ -112,9 +106,7 @@ def align_and_diff(
)

elif len(source_ids) == 1:
target_texts = " ".join(
lookup_text(g.target, id_).strip() for id_ in target_ids
)
target_texts = " ".join(lookup_text(g.target, id_).strip() for id_ in target_ids)
source_text = s_token.text.strip()
start = curr_start
curr_start += 1
Expand All @@ -132,14 +124,13 @@ def align_and_diff(
corrections.append(((start, curr_start), target_text))
else:
# TODO Handle this correct (https://github.com/spraakbanken/sparv-sbx-ocr-correction/issues/50)
raise NotImplementedError(
f"Handle several sources, {source_ids=} {target_ids=} {g.source=} {g.target=}" # noqa: E501
)
raise NotImplementedError(f"Handle several sources, {source_ids=} {target_ids=} {g.source=} {g.target=}")

return corrections, curr_start


def lookup_text(tokens: List[Token], id_: str) -> str:
def lookup_text(tokens: list[Token], id_: str) -> str:
"""Lookup text from a token with id `id_`."""
for token in tokens:
if token.id == id_:
return token.text
Expand Down
6 changes: 2 additions & 4 deletions ocr-correction-viklofg-sweocr/tests/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,14 @@
from sparv_pipeline_testing import MemoryOutput, MockAnnotation


def test_annotate_ocr_correction(snapshot) -> None:
def test_annotate_ocr_correction(snapshot) -> None: # noqa: ANN001
output_ocr: MemoryOutput = MemoryOutput()
output_ocr_corr: MemoryOutput = MemoryOutput()
# "Jonath an saknades ."
# "12345678901234567890"
# " 1 2"
word = MockAnnotation(name="<token:word>", values=["Jonath", "an", "saknades", "."])
sentence = MockAnnotation(
name="<sentence>", children={"<token:word>": [[0, 1, 2, 3]]}
)
sentence = MockAnnotation(name="<sentence>", children={"<token:word>": [[0, 1, 2, 3]]})
token = MockAnnotation(name="<token>", spans=[(0, 6), (7, 9), (10, 18), (19, 20)])

annotate_ocr_correction(output_ocr, output_ocr_corr, word, sentence, token)
Expand Down
6 changes: 3 additions & 3 deletions ocr-correction-viklofg-sweocr/tests/test_ocr_suggestor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sbx_ocr_correction_viklofg_sweocr.ocr_corrector import OcrCorrector


def test_short_text(ocr_corrector: OcrCorrector, snapshot):
def test_short_text(ocr_corrector: OcrCorrector, snapshot) -> None: # noqa: ANN001
text = [
"Den",
"i",
Expand All @@ -21,7 +21,7 @@ def test_short_text(ocr_corrector: OcrCorrector, snapshot):
assert actual == snapshot


def test_long_text(ocr_corrector: OcrCorrector, snapshot):
def test_long_text(ocr_corrector: OcrCorrector, snapshot) -> None: # noqa: ANN001
text1 = [
"Förvaltningen",
"af",
Expand Down Expand Up @@ -62,7 +62,7 @@ def test_long_text(ocr_corrector: OcrCorrector, snapshot):
assert actual == snapshot


def test_issue_40(ocr_corrector: OcrCorrector, snapshot) -> None:
def test_issue_40(ocr_corrector: OcrCorrector, snapshot) -> None: # noqa: ANN001
example = [
"Jonathan",
"saknades",
Expand Down

0 comments on commit d5fa2c3

Please sign in to comment.