Skip to content

Commit

Permalink
refactor: don't use suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
kod-kristoff committed Feb 23, 2024
1 parent a4ad9ea commit 7210b0a
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 31 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
[![CI](https://github.com/spraakbanken/sparv-ocr-correction-plugin/actions/workflows/ci.yml/badge.svg)](https://github.com/spraakbanken/sparv-ocr-correction-plugin/actions/workflows/ci.yml)
[![PyPI version](https://badge.fury.io/py/sparv-ocr-correction-plugin.svg)](https://pypi.org/project/sparv-ocr-correction-plugin)

Sparv plugin to annotate suggestions to OCR:ed documents.
Sparv plugin to annotate corrections to OCR:ed documents.

## Install

Expand Down
2 changes: 1 addition & 1 deletion examples/hello-ocr/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export:
- <sentence>
# - <token:word>
- <token>:stanza.pos
- <token>:ocr_correction.ocr-correction
- <token>:ocr_correction.ocr-correction--viklofg-swedish-ocr

sparv:
compression: none
28 changes: 8 additions & 20 deletions src/ocr_correction/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from sparv.api import ( # type: ignore [import-untyped]
Annotation,
Config,
Output,
annotator,
get_logger,
Expand All @@ -17,18 +16,7 @@

DEFAULT_MODEL_NAME = "viklofg/swedish-ocr-correction"
DEFAULT_TOKENIZER_NAME = "google/byt5-small"
__config__ = [
Config(
"ocr_correction.model",
description="Huggingface pretrained model name",
default=DEFAULT_MODEL_NAME,
),
Config(
"ocr_correction.tokenizer",
description="HuggingFace pretrained tokenizer name",
default=DEFAULT_TOKENIZER_NAME,
),
]


__version__ = "0.1.0"

Expand All @@ -42,18 +30,18 @@
)
def annotate_ocr_correction(
out_ocr_correction: Output = Output(
"<token>:ocr_correction.ocr-correction",
"<token>:ocr_correction.ocr-correction--viklofg-swedish-ocr",
cls="ocr_correction",
description="Neighbours from masked BERT (format: '|<word>:<score>|...|)",
),
word: Annotation = Annotation("<token:word>"),
sentence: Annotation = Annotation("<sentence>"),
model_name: str = Config("ocr_correction.model"),
tokenizer_name: str = Config("ocr_correction.tokenizer"),
) -> None:
tokenizer_name = DEFAULT_TOKENIZER_NAME
model_name = DEFAULT_MODEL_NAME
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
ocr_suggestor = OcrSuggestor(model=model, tokenizer=tokenizer)
ocr_corrector = OcrCorrector(model=model, tokenizer=tokenizer)

sentences, _orphans = sentence.get_children(word)
token_word = list(word.read())
Expand All @@ -64,14 +52,14 @@ def annotate_ocr_correction(
logger.progress() # type: ignore
sent_to_tag = [token_word[token_index] for token_index in sent]

ocr_corrections = ocr_suggestor.calculate_suggestions(sent_to_tag)
ocr_corrections = ocr_corrector.calculate_corrections(sent_to_tag)
out_ocr_correction_annotation[:] = ocr_corrections

logger.info("writing annotations")
out_ocr_correction.write(out_ocr_correction_annotation)


class OcrSuggestor:
class OcrCorrector:
TEXT_LIMIT: int = 127

def __init__(self, *, tokenizer, model) -> None:
Expand All @@ -81,7 +69,7 @@ def __init__(self, *, tokenizer, model) -> None:
"text2text-generation", model=model, tokenizer=tokenizer
)

def calculate_suggestions(self, text: list[str]) -> list[Optional[str]]:
def calculate_corrections(self, text: list[str]) -> list[Optional[str]]:
logger.debug("Analyzing '%s'", text)
parts = []
curr_part: list[str] = []
Expand Down
6 changes: 3 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from ocr_correction import (
DEFAULT_MODEL_NAME,
DEFAULT_TOKENIZER_NAME,
OcrSuggestor,
OcrCorrector,
)
from transformers import ( # type: ignore [import-untyped]
AutoTokenizer,
Expand All @@ -11,7 +11,7 @@


@pytest.fixture(scope="session")
def ocr_suggestor() -> OcrSuggestor:
def ocr_corrector() -> OcrCorrector:
tokenizer = AutoTokenizer.from_pretrained(DEFAULT_TOKENIZER_NAME)
model = T5ForConditionalGeneration.from_pretrained(DEFAULT_MODEL_NAME)
return OcrSuggestor(model=model, tokenizer=tokenizer)
return OcrCorrector(model=model, tokenizer=tokenizer)
12 changes: 6 additions & 6 deletions tests/test_ocr_suggestor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ocr_correction import OcrSuggestor
from ocr_correction import OcrCorrector


def test_short_text(ocr_suggestor: OcrSuggestor):
def test_short_text(ocr_corrector: OcrCorrector):
text = [
"Den",
"i",
Expand All @@ -16,7 +16,7 @@ def test_short_text(ocr_suggestor: OcrSuggestor):
"Frölnndaviken",
".",
]
actual = ocr_suggestor.calculate_suggestions(text)
actual = ocr_corrector.calculate_corrections(text)

expected = [
None,
Expand All @@ -35,7 +35,7 @@ def test_short_text(ocr_suggestor: OcrSuggestor):
assert actual == expected


def test_long_text(ocr_suggestor: OcrSuggestor):
def test_long_text(ocr_corrector: OcrCorrector):
text1 = [
"Förvaltningen",
"af",
Expand Down Expand Up @@ -71,8 +71,8 @@ def test_long_text(ocr_suggestor: OcrSuggestor):
# blifvit dertill utsedd; tillkommande Fullmäktige att sjelfva bland sig välja en
# vice Ordförande att föra ordet, när hinder för Ordföranden inträffar."""
# print(f"{len(text2)=}, {len(text2.encode())=}")
actual = ocr_suggestor.calculate_suggestions(text1)
# actual = ocr_suggestor.calculate_suggestions(text2)
actual = ocr_corrector.calculate_corrections(text1)
# actual = ocr_corrector.calculate_corrections(text2)

expected = [
None,
Expand Down

0 comments on commit 7210b0a

Please sign in to comment.