fix: handle annotations when 2 tokens are merged
Change to use output spans and corrections.

Fixes #44
kod-kristoff committed Nov 7, 2024
1 parent b73f93d commit 21944ac
Showing 11 changed files with 626 additions and 125 deletions.
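The core of the change is the shape of the data coming out of `calculate_corrections`: instead of one optional correction per source token, it now returns one `((start, end), correction)` pair per aligned span of tokens, so a correction that merges two tokens is carried by a single two-token span rather than being forced into a per-token slot. A minimal sketch of the before/after shape (the values are invented for illustration):

from typing import List, Optional, Tuple

# Old shape: one entry per token; a two-token merge had no natural slot.
old_shape: List[Optional[str]] = [None, "Jonathan", None]

# New shape: sentence-relative token spans paired with an optional correction;
# here the merged tokens "Jonath" + "an" share the single span (1, 3).
new_shape: List[Tuple[Tuple[int, int], Optional[str]]] = [
    ((0, 1), None),
    ((1, 3), "Jonathan"),
]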
21 changes: 11 additions & 10 deletions examples/christoph-borg/config.yaml
@@ -1,16 +1,17 @@
 metadata:
-    id: christoph-borg
-    language: swe
+    id: christoph-borg
+    language: swe

 import:
-    importer: text_import:parse
+    importer: text_import:parse

 export:
-    annotations:
-        - <sentence>
-        # - <token:word>
-        - <token>:stanza.pos
-        - <token>:sbx_ocr_correction_viklofg_sweocr.ocr-correction--viklofg-sweocr
+    annotations:
+        - <sentence>
+        # - <token:word>
+        - <token>:stanza.pos
+        - sbx_ocr_correction_viklofg_sweocr.sbx-ocr-correction
+        - sbx_ocr_correction_viklofg_sweocr.sbx-ocr-correction:sbx_ocr_correction_viklofg_sweocr.ocr-correction--viklofg-sweocr

 sparv:
-    compression: none
+    compression: none
19 changes: 10 additions & 9 deletions examples/ocr-correction-viklofg-sweocr/config.yaml
@@ -1,16 +1,17 @@
 metadata:
-    id: hello-ocr
-    language: swe
+    id: hello-ocr
+    language: swe

 import:
-    importer: text_import:parse
+    importer: text_import:parse

 export:
-    annotations:
-        - <sentence>
-        # - <token:word>
-        - <token>:stanza.pos
-        - <token>:sbx_ocr_correction_viklofg_sweocr.ocr-correction--viklofg-sweocr
+    annotations:
+        - <sentence>
+        # - <token:word>
+        - <token>:stanza.pos
+        - sbx_ocr_correction_viklofg_sweocr.sbx-ocr-correction
+        - sbx_ocr_correction_viklofg_sweocr.sbx-ocr-correction:sbx_ocr_correction_viklofg_sweocr.ocr-correction--viklofg-sweocr

 sparv:
-    compression: none
+    compression: none
3 changes: 2 additions & 1 deletion examples/texts/dokument.txt
@@ -1 +1,2 @@
-Den i HandelstidniDgens g&rdagsnnmmer omtalade hvalfisken, sorn fångats i Frölnndaviken
+Den i HandelstidniDgens g&rdagsnnmmer omtalade hvalfisken, sorn fångats i Frölnndaviken.
+Jonath an saknades.
@@ -1,3 +1,5 @@
+from typing import List, Optional, Tuple
+
 from sparv import api as sparv_api  # type: ignore [import-untyped]
 from sparv.api import Annotation, Output, annotator  # type: ignore [import-untyped]

@@ -8,28 +10,74 @@

 @annotator("OCR corrections as annotations", language=["swe"])
 def annotate_ocr_correction(
-    out_ocr_correction: Output = Output(
-        "<token>:sbx_ocr_correction_viklofg_sweocr.ocr-correction--viklofg-sweocr",
-        cls="sbx_ocr_correction_viklofg_sweocr",
-        description="OCR Corrections from viklfog/swedish-ocr (format: '|<word>:<score>|...|)",  # noqa: E501
+    out_ocr: Output = Output(
+        "sbx_ocr_correction_viklofg_sweocr.sbx-ocr-correction", cls="ocr_correction"
     ),
+    out_ocr_corr: Output = Output(
+        "sbx_ocr_correction_viklofg_sweocr.sbx-ocr-correction:sbx_ocr_correction_viklofg_sweocr.ocr-correction--viklofg-sweocr",
+        cls="ocr_correction:correction",
+    ),
+    # out_ocr_correction: Output = Output(
+    #     "<token>:sbx_ocr_correction_viklofg_sweocr.ocr-correction--viklofg-sweocr",
+    #     cls="sbx_ocr_correction_viklofg_sweocr",
+    #     description="OCR Corrections from viklfog/swedish-ocr (format: '|<word>:<score>|...|)",  # noqa: E501
+    # ),
     word: Annotation = Annotation("<token:word>"),
     sentence: Annotation = Annotation("<sentence>"),
+    token: Annotation = Annotation("<token>"),
 ) -> None:
     ocr_corrector = OcrCorrector.default()

     sentences, _orphans = sentence.get_children(word)
     token_word = list(word.read())
-    out_ocr_correction_annotation = word.create_empty_attribute()
+    # out_ocr_correction_annotation = word.create_empty_attribute()
+
+    ocr_corrections = []

     logger.progress(total=len(sentences))  # type: ignore
-    for sent in sentences:
+    for sent_idx in sentences:
         logger.progress()  # type: ignore
-        sent_to_tag = [token_word[token_index] for token_index in sent]
+        sent = [token_word[token_index] for token_index in sent_idx]
+
+        ocr_corrections.append(ocr_corrector.calculate_corrections(sent))
+        # for i, ocr_correction in enumerate(ocr_corrections, start=sent[0]):
+        #     out_ocr_correction_annotation[i] = ocr_correction

+    # logger.info("writing annotations")
+    # out_ocr.write(ocr_spans)
+    # out_ocr_corr.write(ocr_corr_ann)
+    parse_ocr_corrections(sentences, token, ocr_corrections, out_ocr, out_ocr_corr)
+
+
+def parse_ocr_corrections(
+    sentences: List,
+    token: Annotation,
+    ocr_corrections: List[List[Tuple[Tuple[int, int], Optional[str]]]],
+    out_ocr: Output,
+    out_ocr_corr: Output,
+) -> None:
+    ocr_spans = []
+    ocr_corr_ann = []
+
+    token_spans = list(token.read_spans())
+    for sent, corr_sent in zip(sentences, ocr_corrections):
+        i = 0
+        for span, corr_opt in corr_sent:
+            start_pos = token_spans[sent[i]][0]
+
+            i += span[1] - span[0]
+
-        ocr_corrections = ocr_corrector.calculate_corrections(sent_to_tag)
-        for i, ocr_correction in enumerate(ocr_corrections, start=sent[0]):
-            out_ocr_correction_annotation[i] = ocr_correction
+            end_pos = token_spans[sent[i - 1]][1]
+            logger.debug(
+                "(%d, %d): '%s'",
+                start_pos,
+                end_pos,
+                "" if corr_opt is None else corr_opt,
+            )
+            if corr_opt is not None:
+                ocr_spans.append((start_pos, end_pos))
+                ocr_corr_ann.append(corr_opt)

     logger.info("writing annotations")
-    out_ocr_correction.write(out_ocr_correction_annotation)
+    out_ocr.write(ocr_spans)
+    out_ocr_corr.write(ocr_corr_ann)
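To make the span arithmetic in `parse_ocr_corrections` concrete, here is a self-contained sketch of the same computation with plain lists standing in for Sparv's `Annotation`/`Output` objects; the character offsets are invented:

from typing import List, Optional, Tuple

# Stand-ins: character spans per token, one sentence of token indices,
# and that sentence's ((start, end), correction) pairs from the corrector.
token_spans: List[Tuple[int, int]] = [(0, 6), (7, 9), (10, 18)]  # "Jonath", "an", "saknades"
sent: List[int] = [0, 1, 2]
corr_sent: List[Tuple[Tuple[int, int], Optional[str]]] = [((0, 2), "Jonathan"), ((2, 3), None)]

ocr_spans, ocr_corr_ann = [], []
i = 0
for span, corr_opt in corr_sent:
    start_pos = token_spans[sent[i]][0]    # where the span's first token starts
    i += span[1] - span[0]                 # consume as many tokens as the span covers
    end_pos = token_spans[sent[i - 1]][1]  # where the span's last token ends
    if corr_opt is not None:
        ocr_spans.append((start_pos, end_pos))
        ocr_corr_ann.append(corr_opt)

print(ocr_spans, ocr_corr_ann)  # [(0, 9)] ['Jonathan']

The merged pair "Jonath" + "an" comes out as a single character span (0, 9) with one correction, which is exactly the merge case issue #44 describes.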
@@ -1,7 +1,8 @@
 import re
-from typing import List, Optional
+from typing import List, Optional, Tuple

 from parallel_corpus import graph
+from parallel_corpus.token import Token
 from sparv import api as sparv_api  # type: ignore [import-untyped]
 from transformers import (  # type: ignore [import-untyped]
     AutoTokenizer,
@@ -44,13 +45,15 @@ def default(cls) -> "OcrCorrector":
         )
         return cls(model=model, tokenizer=tokenizer)

-    def calculate_corrections(self, text: List[str]) -> List[Optional[str]]:
+    def calculate_corrections(
+        self, text: List[str]
+    ) -> List[Tuple[Tuple[int, int], Optional[str]]]:
         logger.debug("Analyzing '%s'", text)

         parts: List[str] = []
         curr_part: List[str] = []
         curr_len = 0
-        ocr_corrections: List[Optional[str]] = []
+        ocr_corrections: List[Tuple[Tuple[int, int], Optional[str]]] = []
         for word in text:
             len_word = bytes_length(word)
             if (curr_len + len_word + 1) > self.TEXT_LIMIT:
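The chunking around these lines (partly folded out of the diff) batches tokens so that each part stays under a byte limit before it is sent through the model. A self-contained sketch of that accumulation; the limit is made up, `bytes_length` is a stand-in for the module's helper, and the folded branch is reconstructed by assumption:

from typing import List

TEXT_LIMIT = 20  # made-up; the real value lives on OcrCorrector
TOK_SEP = " "    # assumed token separator

def bytes_length(s: str) -> int:  # stand-in: UTF-8 byte count
    return len(s.encode("utf-8"))

words = ["Den", "i", "HandelstidniDgens", "g&rdagsnnmmer"]
parts: List[str] = []
curr_part: List[str] = []
curr_len = 0
for word in words:
    len_word = bytes_length(word)
    if (curr_len + len_word + 1) > TEXT_LIMIT:  # +1 for the separator
        parts.append(TOK_SEP.join(curr_part))
        curr_part, curr_len = [word], len_word
    else:
        curr_part.append(word)
        curr_len = len_word if curr_len == 0 else curr_len + len_word + 1
if len(curr_part) > 0:
    parts.append(TOK_SEP.join(curr_part))

print(parts)  # ['Den i', 'HandelstidniDgens', 'g&rdagsnnmmer']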
@@ -61,66 +64,85 @@ def calculate_corrections(self, text: List[str]) -> List[Optional[str]]:
                 curr_len = len_word if curr_len == 0 else curr_len + len_word + 1
         if len(curr_part) > 0:
             parts.append(TOK_SEP.join(curr_part))
+        curr_start = 0
         for part in parts:
-            graph_initial = graph.init(part)
             suggested_text = self.pipeline(part)[0]["generated_text"]

-            suggested_text = PUNCTUATION.sub(r" \0", suggested_text)
-            graph_aligned = graph.set_target(graph_initial, suggested_text)
-            ocr_corrections.extend(align_and_diff(graph_aligned))
+            suggested_text = PUNCTUATION.sub(r" \g<0>", suggested_text)
+            graph_aligned = graph.init_with_source_and_target(part, suggested_text)
+            span_ann, curr_start = align_and_diff(graph_aligned, curr_start=curr_start)
+            ocr_corrections.extend(span_ann)

         logger.debug("Finished analyzing. ocr_corrections=%s", ocr_corrections)
         return ocr_corrections
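One change here is easy to miss but is a genuine bug fix: in Python's `re` replacement templates, `\0` is an octal escape for the NUL character, not a reference to the whole match; `\g<0>` is the correct spelling for "the entire match". A quick demonstration (the character class is a hypothetical stand-in, since `PUNCTUATION`'s definition is not shown in this diff):

import re

PUNCTUATION = re.compile(r"[.,!?]")  # hypothetical stand-in pattern

text = "hvalfisken,sorn."
print(repr(PUNCTUATION.sub(r" \0", text)))     # 'hvalfisken \x00sorn \x00' -- NUL bytes
print(repr(PUNCTUATION.sub(r" \g<0>", text)))  # 'hvalfisken ,sorn .' -- space before each match

So the old code was injecting NUL bytes instead of spacing off the punctuation before alignment.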


-def align_and_diff(g: graph.Graph) -> List[Optional[str]]:
+def align_and_diff(
+    g: graph.Graph, *, curr_start: int
+) -> Tuple[List[Tuple[Tuple[int, int], Optional[str]]], int]:
     corrections = []
     edge_map = graph.edge_map(g)
     visited_tokens = set()
     for s_token in g.source:
         logger.debug("checking s_token=%s", s_token)
         edge = edge_map[s_token.id]

         source_ids = [id_ for id_ in edge.ids if id_.startswith("s")]
         target_ids = [id_ for id_ in edge.ids if id_.startswith("t")]
         target_ids_str = "-".join(target_ids)
         if target_ids_str in visited_tokens:
             continue
         visited_tokens.add(target_ids_str)
         logger.debug("processing s_token=%s", s_token)

         if len(source_ids) == len(target_ids):
             source_text = "".join(
-                lookup_text(g, s_id, graph.Side.source) for s_id in source_ids
+                lookup_text(g.source, s_id) for s_id in source_ids
             ).strip()
             target_text = "".join(
-                lookup_text(g, s_id, graph.Side.target) for s_id in target_ids
+                lookup_text(g.target, s_id) for s_id in target_ids
             ).strip()
-            corrections.append(target_text if source_text != target_text else None)
+            start = curr_start
+            curr_start += 1
+            corrections.append(
+                (
+                    (start, curr_start),
+                    target_text if source_text != target_text else None,
+                )
+            )

         elif len(source_ids) == 1:
             target_texts = " ".join(
-                lookup_text(g, id_, graph.Side.target).strip() for id_ in target_ids
+                lookup_text(g.target, id_).strip() for id_ in target_ids
             )
             source_text = s_token.text.strip()
-            corrections.append(target_texts if source_text != target_texts else None)
-        elif len(target_ids) == 1:
-            # TODO Handle this correct (https://github.com/spraakbanken/sparv-sbx-ocr-correction/issues/44)
-            logger.warn(
-                f"Handle several sources, see https://github.com/spraakbanken/sparv-sbx-ocr-correction/issues/44, {source_ids=} {target_ids=} {g.source=} {g.target=}"  # noqa: E501
+            start = curr_start
+            curr_start += 1
+
+            corrections.append(
+                (
+                    (start, curr_start),
+                    target_texts if source_text != target_texts else None,
+                ),
             )
-            target_text = lookup_text(g, target_ids[0], graph.Side.target).strip()
-            corrections.append(target_text)
+        elif len(target_ids) == 1:
+            target_text = lookup_text(g.target, target_ids[0]).strip()
+            start = curr_start
+            curr_start += len(source_ids)
+            corrections.append(((start, curr_start), target_text))
         else:
-            # TODO Handle this correct (https://github.com/spraakbanken/sparv-sbx-ocr-correction/issues/44)
+            # TODO Handle this correct (https://github.com/spraakbanken/sparv-sbx-ocr-correction/issues/50)
             raise NotImplementedError(
                 f"Handle several sources, {source_ids=} {target_ids=} {g.source=} {g.target=}"  # noqa: E501
             )

-    return corrections
+    return corrections, curr_start
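The keyword-only `curr_start` threads a running sentence-relative token index across calls (one call per text part), and each branch advances it by the number of source tokens it consumes: one for an aligned or split token, `len(source_ids)` for a merge. A compressed illustration of the bookkeeping with the graph machinery stripped away (words invented, in the spirit of the test fixture):

from typing import List, Optional, Tuple

corrections: List[Tuple[Tuple[int, int], Optional[str]]] = []
curr_start = 0

# 1-1 alignment: "sorn" -> "som" consumes one source token.
start, curr_start = curr_start, curr_start + 1
corrections.append(((start, curr_start), "som"))

# 2-1 merge: "Jonath" + "an" -> "Jonathan" consumes two source tokens.
source_ids = ["s1", "s2"]
start, curr_start = curr_start, curr_start + len(source_ids)
corrections.append(((start, curr_start), "Jonathan"))

print(corrections)  # [((0, 1), 'som'), ((1, 3), 'Jonathan')]

The merge branch is the one that previously emitted a warning and a single per-token entry, desynchronizing the annotation from the tokens; giving it an explicit multi-token span is what closes #44.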


+def lookup_text(tokens: List[Token], id_: str) -> str:
+    for token in tokens:
+        if token.id == id_:
+            return token.text

-def lookup_text(g: graph.Graph, id_: str, side: graph.Side) -> str:
-    if side == graph.Side.source:
-        for token in g.source:
-            if token.id == id_:
-                return token.text
-    else:
-        for token in g.target:
-            if token.id == id_:
-                return token.text
     raise ValueError(
-        f"The id={id_} isn't found in the given graph on side={side}",
+        f"The id={id_} isn't found in the list of tokens",
     )
@@ -0,0 +1,14 @@
+# serializer version: 1
+# name: test_annotate_ocr_correction
+  list([
+    tuple(
+      7,
+      18,
+    ),
+  ])
+# ---
+# name: test_annotate_ocr_correction.1
+  list([
+    'ansaknades',
+  ])
+# ---
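The new file is a snapshot in syrupy's `.ambr` format (the `# serializer version: 1` header is syrupy's): each `# name:` block stores one asserted value, and a second assertion in the same test gets the `.1` suffix. Assuming that setup, the test driving this snapshot would look roughly like the following (`run_annotator_on_fixture` is a hypothetical helper):

def test_annotate_ocr_correction(snapshot):
    # Hypothetical sketch: run the annotator on the example document and
    # snapshot the character spans and correction strings it writes.
    ocr_spans, ocr_corr_ann = run_annotator_on_fixture()
    assert ocr_spans == snapshot      # stored as 'test_annotate_ocr_correction'
    assert ocr_corr_ann == snapshot   # stored as 'test_annotate_ocr_correction.1'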