Skip to content

Commit

Permalink
build: rename project to sparv-ocr-correction-plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
kod-kristoff committed Feb 22, 2024
1 parent 7a5f64d commit a4ad9ea
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 34 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,10 @@ jobs:
name: pypi_files
path: dist

- run: rm -r src/ocr_suggestion
- run: rm -r src/ocr_correction
- run: pip install typing-extensions
- run: pip install -r tests/requirements-testing.txt
- run: pip install sparv-ocr-suggestion-plugin --no-index --no-deps --find-links dist --force-reinstall
- run: pip install sparv-ocr-correction-plugin --no-index --no-deps --find-links dist --force-reinstall
- run: pytest

publish:
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ help:
@echo ""

PLATFORM := `uname -o`
REPO := "sparv-ocr-suggestion-plugin"
PROJECT_SRC := "src/ocr_suggestion"
REPO := "sparv-ocr-correction-plugin"
PROJECT_SRC := "src/ocr_correction"

ifeq (${VIRTUAL_ENV},)
VENV_NAME = .venv
Expand Down
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# sparv-ocr-suggestion-plugin
# sparv-ocr-correction-plugin

[![CI](https://github.com/spraakbanken/sparv-ocr-suggestion-plugin/actions/workflows/ci.yml/badge.svg)](https://github.com/spraakbanken/sparv-ocr-suggestion-plugin/actions/workflows/ci.yml)
[![PyPI version](https://badge.fury.io/py/sparv-ocr-suggestion-plugin.svg)](https://pypi.org/project/sparv-ocr-suggestion-plugin)
[![CI](https://github.com/spraakbanken/sparv-ocr-correction-plugin/actions/workflows/ci.yml/badge.svg)](https://github.com/spraakbanken/sparv-ocr-correction-plugin/actions/workflows/ci.yml)
[![PyPI version](https://badge.fury.io/py/sparv-ocr-correction-plugin.svg)](https://pypi.org/project/sparv-ocr-correction-plugin)

Sparv plugin to annotate OCR:ed documents with correction suggestions.

Expand All @@ -10,11 +10,11 @@ Sparv plugin to annotate suggestions to OCR:ed documents.
In a virtual environment:

```bash
pip install sparv-ocr-suggestion-plugin
pip install sparv-ocr-correction-plugin
```

or if you have `sparv` installed with `pipx`:

```bash
pipx inject sparv-pipeline sparv-ocr-suggestion-plugin
pipx inject sparv-pipeline sparv-ocr-correction-plugin
```
2 changes: 1 addition & 1 deletion examples/hello-ocr/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export:
- <sentence>
# - <token:word>
- <token>:stanza.pos
- <token>:ocr_suggestion.ocr-suggestion
- <token>:ocr_correction.ocr-correction

sparv:
compression: none
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[project]
name = "sparv-ocr-suggestion-plugin"
name = "sparv-ocr-correction-plugin"
version = "0.1.0"
description = "A sparv plugin for computing suggested OCR improvements."
authors = [
Expand Down Expand Up @@ -30,18 +30,18 @@ requires = ["hatchling"]
build-backend = "hatchling.build"

[project.entry-points."sparv.plugin"]
ocr_suggestion = "ocr_suggestion"
ocr_correction = "ocr_correction"

[project.urls]
Homepage = "https://spraakbanken.gu.se"
Repository = "https://github.com/spraakbanken/sparv-ocr-suggestion-plugin"
"Bug Tracker" = "https://github.com/spraakbanken/sparv-ocr-suggestion-plugin/issues"
Repository = "https://github.com/spraakbanken/sparv-ocr-correction-plugin"
"Bug Tracker" = "https://github.com/spraakbanken/sparv-ocr-correction-plugin/issues"

[tool.hatch.build.targets.sdist]
exclude = ["/.github", "/docs"]

[tool.hatch.build.targets.wheel]
packages = ["src/ocr_suggestion"]
packages = ["src/ocr_correction"]

[tool.hatch.metadata]
allow-direct-references = true
Expand Down
34 changes: 17 additions & 17 deletions src/ocr_suggestion/__init__.py → src/ocr_correction/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
DEFAULT_TOKENIZER_NAME = "google/byt5-small"
__config__ = [
Config(
"ocr_suggestion.model",
"ocr_correction.model",
description="Huggingface pretrained model name",
default=DEFAULT_MODEL_NAME,
),
Config(
"ocr_suggestion.tokenizer",
"ocr_correction.tokenizer",
description="HuggingFace pretrained tokenizer name",
default=DEFAULT_TOKENIZER_NAME,
),
Expand All @@ -40,35 +40,35 @@
@annotator(
"Word neighbour tagging with a masked Bert model",
)
def annotate_ocr_suggestion(
out_ocr_suggestion: Output = Output(
"<token>:ocr_suggestion.ocr-suggestion",
cls="ocr_suggestion",
def annotate_ocr_correction(
out_ocr_correction: Output = Output(
"<token>:ocr_correction.ocr-correction",
cls="ocr_correction",
description="Neighbours from masked BERT (format: '|<word>:<score>|...|)",
),
word: Annotation = Annotation("<token:word>"),
sentence: Annotation = Annotation("<sentence>"),
model_name: str = Config("ocr_suggestion.model"),
tokenizer_name: str = Config("ocr_suggestion.tokenizer"),
model_name: str = Config("ocr_correction.model"),
tokenizer_name: str = Config("ocr_correction.tokenizer"),
) -> None:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
ocr_suggestor = OcrSuggestor(model=model, tokenizer=tokenizer)

sentences, _orphans = sentence.get_children(word)
token_word = list(word.read())
out_ocr_suggestion_annotation = word.create_empty_attribute()
out_ocr_correction_annotation = word.create_empty_attribute()

logger.progress(total=len(sentences)) # type: ignore
for sent in sentences:
logger.progress() # type: ignore
sent_to_tag = [token_word[token_index] for token_index in sent]

ocr_suggestions = ocr_suggestor.calculate_suggestions(sent_to_tag)
out_ocr_suggestion_annotation[:] = ocr_suggestions
ocr_corrections = ocr_suggestor.calculate_suggestions(sent_to_tag)
out_ocr_correction_annotation[:] = ocr_corrections

logger.info("writing annotations")
out_ocr_suggestion.write(out_ocr_suggestion_annotation)
out_ocr_correction.write(out_ocr_correction_annotation)


class OcrSuggestor:
Expand All @@ -86,7 +86,7 @@ def calculate_suggestions(self, text: list[str]) -> list[Optional[str]]:
parts = []
curr_part: list[str] = []
curr_len = 0
ocr_suggestions: list[str] = []
ocr_corrections: list[str] = []
for word in text:
len_word = bytes_length(word)
if (curr_len + len_word + 1) > self.TEXT_LIMIT:
Expand All @@ -101,11 +101,11 @@ def calculate_suggestions(self, text: list[str]) -> list[Optional[str]]:
suggested_text = self.pipeline(part)[0]["generated_text"]
suggested_text = suggested_text.replace(",", " ,")
suggested_text = suggested_text.replace(".", " .")
ocr_suggestions = ocr_suggestions + suggested_text.split(TOK_SEP)
ocr_corrections = ocr_corrections + suggested_text.split(TOK_SEP)

if len(text) == len(ocr_suggestions) + 1 and text[-1] != ocr_suggestions[-1]:
ocr_suggestions.append(text[-1])
return zip_and_diff(text, ocr_suggestions)
if len(text) == len(ocr_corrections) + 1 and text[-1] != ocr_corrections[-1]:
ocr_corrections.append(text[-1])
return zip_and_diff(text, ocr_corrections)


def zip_and_diff(orig: list[str], sugg: list[str]) -> list[Optional[str]]:
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from ocr_suggestion import (
from ocr_correction import (
DEFAULT_MODEL_NAME,
DEFAULT_TOKENIZER_NAME,
OcrSuggestor,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ocr_suggestor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ocr_suggestion import OcrSuggestor
from ocr_correction import OcrSuggestor


def test_short_text(ocr_suggestor: OcrSuggestor):
Expand Down

0 comments on commit a4ad9ea

Please sign in to comment.