Skip to content

Commit

Permalink
fix: align text with corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
kod-kristoff committed May 8, 2024
1 parent df8ff1c commit 1a87220
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 31 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import re
from typing import List, Optional

from parallel_corpus import graph
from sparv import api as sparv_api # type: ignore [import-untyped]
from transformers import ( # type: ignore [import-untyped]
AutoTokenizer,
Expand All @@ -12,17 +14,15 @@ def bytes_length(s: str) -> int:
return len(s.encode("utf-8"))


def zip_and_diff(orig: List[str], sugg: List[str]) -> List[Optional[str]]:
    """Pairwise-compare two token lists, keeping a suggestion only where it differs.

    Tokens are compared position by position (truncated to the shorter list);
    positions where the suggestion equals the original yield ``None``.
    """
    diffs: List[Optional[str]] = []
    for original_token, suggested_token in zip(orig, sugg):
        diffs.append(suggested_token if suggested_token != original_token else None)
    return diffs


TOK_SEP = " "
logger = sparv_api.get_logger(__name__)
TOKENIZER_REVISION = "68377bdc18a2ffec8a0533fef03b1c513a4dd49d"
TOKENIZER_NAME = "google/byt5-small"
MODEL_REVISION = "84b138048992271be7617ccb11056bbcb9b72262"
MODEL_NAME = "viklofg/swedish-ocr-correction"

PUNCTUATION = re.compile(r"[.,:;!?]")


class OcrCorrector:
TEXT_LIMIT: int = 127
Expand All @@ -46,10 +46,11 @@ def default(cls) -> "OcrCorrector":

def calculate_corrections(self, text: List[str]) -> List[Optional[str]]:
logger.debug("Analyzing '%s'", text)
parts = []

parts: List[str] = []
curr_part: List[str] = []
curr_len = 0
ocr_corrections: List[str] = []
ocr_corrections: List[Optional[str]] = []
for word in text:
len_word = bytes_length(word)
if (curr_len + len_word + 1) > self.TEXT_LIMIT:
Expand All @@ -61,11 +62,65 @@ def calculate_corrections(self, text: List[str]) -> List[Optional[str]]:
if len(curr_part) > 0:
parts.append(TOK_SEP.join(curr_part))
for part in parts:
graph_initial = graph.init(part)
suggested_text = self.pipeline(part)[0]["generated_text"]
suggested_text = suggested_text.replace(",", " ,")
suggested_text = suggested_text.replace(".", " .")
ocr_corrections = ocr_corrections + suggested_text.split(TOK_SEP)

if len(text) == len(ocr_corrections) + 1 and text[-1] != ocr_corrections[-1]:
ocr_corrections.append(text[-1])
return zip_and_diff(text, ocr_corrections)
suggested_text = PUNCTUATION.sub(r" \0", suggested_text)
graph_aligned = graph.set_target(graph_initial, suggested_text)
ocr_corrections.extend(align_and_diff(graph_aligned))

logger.debug("Finished analyzing. ocr_corrections=%s", ocr_corrections)
return ocr_corrections


def align_and_diff(g: graph.Graph) -> List[Optional[str]]:
    """Walk an aligned graph and emit one correction per source token.

    For each source token the aligned edge is inspected:

    * equal numbers of source/target ids: compare the joined texts and
      emit ``None`` when unchanged;
    * one source id, several target ids (a split): join the target texts;
    * several source ids, one target id (a merge): not handled yet — a
      warning is logged and the target text is used as-is (issue #44);
    * several sources and several targets: ``NotImplementedError``.

    Returns a list parallel to ``g.source`` where ``None`` means "no change".
    """
    # Annotated to match the declared return type (entries may be None).
    corrections: List[Optional[str]] = []
    edge_map = graph.edge_map(g)
    for s_token in g.source:
        edge = edge_map[s_token.id]

        # Ids are prefixed "s" (source side) or "t" (target side)
        # — presumably by parallel_corpus; verify against that library.
        source_ids = [id_ for id_ in edge.ids if id_.startswith("s")]
        target_ids = [id_ for id_ in edge.ids if id_.startswith("t")]
        if len(source_ids) == len(target_ids):
            source_text = "".join(
                lookup_text(g, s_id, graph.Side.source) for s_id in source_ids
            ).strip()
            target_text = "".join(
                lookup_text(g, t_id, graph.Side.target) for t_id in target_ids
            ).strip()
            corrections.append(target_text if source_text != target_text else None)

        elif len(source_ids) == 1:
            # One source token aligned to several targets: join them with
            # spaces so the whole replacement is attached to this token.
            target_texts = " ".join(
                lookup_text(g, id_, graph.Side.target).strip() for id_ in target_ids
            )
            source_text = s_token.text.strip()
            corrections.append(target_texts if source_text != target_texts else None)
        elif len(target_ids) == 1:
            # TODO Handle this correctly (https://github.com/spraakbanken/sparv-sbx-ocr-correction/issues/44)
            # logger.warn is the deprecated alias of logger.warning.
            logger.warning(
                f"Handle several sources, see https://github.com/spraakbanken/sparv-sbx-ocr-correction/issues/44, {source_ids=} {target_ids=} {g.source=} {g.target=}"  # noqa: E501
            )
            target_text = lookup_text(g, target_ids[0], graph.Side.target).strip()
            corrections.append(target_text)
        else:
            # TODO Handle this correctly (https://github.com/spraakbanken/sparv-sbx-ocr-correction/issues/44)
            raise NotImplementedError(
                f"Handle several sources, {source_ids=} {target_ids=} {g.source=} {g.target=}"  # noqa: E501
            )

    return corrections


def lookup_text(g: graph.Graph, id_: str, side: graph.Side) -> str:
    """Return the text of the token with ``id_`` on the given ``side`` of ``g``.

    Raises ``ValueError`` when no token on that side carries the id.
    """
    # Pick the token sequence for the requested side, then scan it.
    tokens = g.source if side == graph.Side.source else g.target
    for token in tokens:
        if token.id == id_:
            return token.text
    raise ValueError(
        f"The id={id_} isn't found in the given graph on side={side}",
    )
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
# serializer version: 1
# name: test_issue_40
<class 'list'> [
'Jonat+han',
list([
'Jonat han',
None,
'',
None,
None,
'',
None,
None,
None,
'',
None,
None,
None,
Expand All @@ -15,17 +19,15 @@
None,
None,
None,
None,
None,
None,
None,
]
---
'',
])
# ---
# name: test_long_text
<class 'list'> [
list([
None,
None,
'Riksgäldskontoret',
'',
None,
None,
None,
Expand All @@ -37,8 +39,7 @@
None,
None,
None,
None,
None,
'',
None,
None,
None,
Expand All @@ -51,22 +52,22 @@
None,
'Riksdagsordningen',
None,
None,
]
---
'',
])
# ---
# name: test_short_text
<class 'list'> [
list([
None,
None,
'Handelstidningens',
'gårdagsnummer',
None,
None,
None,
'',
'som',
None,
None,
'Frölandsviken',
None,
]
---
'',
])
# ---

0 comments on commit 1a87220

Please sign in to comment.