diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml
index c676a96..10e47ef 100644
--- a/.github/workflows/check.yml
+++ b/.github/workflows/check.yml
@@ -44,15 +44,7 @@ jobs:
           version: ${{ env.UV_VERSION }}
           enable-cache: true
 
-      - name: Load cached venv
-        id: cached-venv
-        uses: actions/cache@v4
-        with:
-          path: .venv
-          key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/pyproject.toml') }}-${{ hashFiles('**/uv.lock') }}
-
       - name: Install dependencies
-        if: steps.cached-venv.outputs.cache-hit != 'true'
         run: make install-dev
 
       - name: check formatting
@@ -80,15 +72,7 @@ jobs:
           version: ${{ env.UV_VERSION }}
           enable-cache: true
 
-      - name: Load cached venv
-        id: cached-venv
-        uses: actions/cache@v4
-        with:
-          path: .venv
-          key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/pyproject.toml') }}-${{ hashFiles('**/uv.lock') }}
-
       - name: Install dependencies
-        if: steps.cached-venv.outputs.cache-hit != 'true'
         run: make install-dev
 
       - name: lint code
@@ -116,15 +100,7 @@ jobs:
           version: ${{ env.UV_VERSION }}
           enable-cache: true
 
-      - name: Load cached venv
-        id: cached-venv
-        uses: actions/cache@v4
-        with:
-          path: .venv
-          key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/pyproject.toml') }}-${{ hashFiles('**/uv.lock') }}
-
       - name: Install dependencies
-        if: steps.cached-venv.outputs.cache-hit != 'true'
         run: make install-dev
 
       - name: type-check code
@@ -144,4 +120,3 @@ jobs:
         uses: re-actors/alls-green@release/v1
         with:
           jobs: ${{ toJSON(needs) }}
-          allowed-failures: upload-coverage
diff --git a/.github/workflows/scheduled.yml b/.github/workflows/scheduled.yml
index ab75168..6a88bb0 100644
--- a/.github/workflows/scheduled.yml
+++ b/.github/workflows/scheduled.yml
@@ -21,38 +21,6 @@ env:
   UV_VERSION: "0.5.1"
 
 jobs:
-  # https://twitter.com/mycoliza/status/1571295690063753218
-  nightly:
-    runs-on: ubuntu-latest
-    name: ubuntu / 3.14-dev
-    steps:
-      - uses: actions/checkout@v4
-        with:
-          submodules: true
-
-      - name: Set up uv
-        uses: astral-sh/setup-uv@v3
-        with:
-          version: ${{ env.UV_VERSION }}
-          enable-cache: true
-
-      - name: Install python
-        uses: actions/setup-python@v5
-        with:
-          python-version: "3.14-dev"
-
-      - run: python --version
-
-      - name: uv lock
-        if: hashFiles('uv.lock') == ''
-        run: uv lock
-
-      - name: uv sync --dev
-        run: uv sync --dev
-
-      - name: make test
-        run: make test
-
   # https://twitter.com/alcuadrado/status/1571291687837732873
   update:
     # This action checks that updating the dependencies of this crate to the latest available that
@@ -85,7 +53,7 @@ jobs:
 
       - name: uv sync --dev --upgrade
         if: hashFiles('uv.lock') != ''
-        run: uv sync --dev --upgrade
+        run: uv sync --dev --upgrade --all-packages --all-extras
 
       - name: make test
         if: hashFiles('uv.lock') != ''
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index a748cb3..0b6f38b 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -48,26 +48,20 @@ jobs:
         with:
           submodules: true
 
+      - name: Set up uv
+        uses: astral-sh/setup-uv@v3
+        with:
+          version: ${{ env.UV_VERSION }}
+          enable-cache: true
+
       - name: Set up Python ${{ matrix.python-version }}
         id: setup-python
         uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
 
-      - name: Set up uv
-        run: curl -LsSf https://astral.sh/uv/${{ env.UV_VERSION }}/install.sh | sh
-
-      - name: Restore uv cache
-        uses: actions/cache@v4
-        with:
-          path: ${{ env.UV_CACHE_DIR }}
-          key: uv-${{ runner.os }}-${{ hashFiles('uv.lock') }}
-          restore-keys: |
-            uv-${{ runner.os }}-${{ hashFiles('uv.lock') }}
-            uv-${{ runner.os }}
-
       - name: Install dependencies
-        run: uv sync --all-extras --all-packages --dev
+        run: make install-dev
 
       - name: Run tests for coverage
         run: make test-w-coverage cov_report=xml
diff --git a/Makefile b/Makefile
index a569233..6e17166 100644
--- a/Makefile
+++ b/Makefile
@@ -84,11 +84,11 @@ dev: install-dev
 
 # setup development environment
 install-dev:
-	uv sync --all-packages --group dev --verbose
+	uv sync --all-packages --dev
 
 # setup production environment
 install:
-	uv sync --no-dev
+	uv sync --all-packages --no-dev
 
 lock: uv.lock
 
@@ -152,7 +152,7 @@ prepare-release: update-changelog tests/requirements-testing.lock
 
 # we use lock extension so that dependabot doesn't pick up changes in this file
 tests/requirements-testing.lock: pyproject.toml
-	uv export --dev --format requirements-txt --no-hashes --output-file $@
+	uv export --dev --format requirements-txt --no-hashes --no-emit-project --output-file $@
 
 .PHONY: update-changelog
 update-changelog: CHANGELOG.md
diff --git a/ocr-correction-viklofg-sweocr/src/sbx_ocr_correction_viklofg_sweocr/ocr_corrector.py b/ocr-correction-viklofg-sweocr/src/sbx_ocr_correction_viklofg_sweocr/ocr_corrector.py
index f0957ab..1acc675 100644
--- a/ocr-correction-viklofg-sweocr/src/sbx_ocr_correction_viklofg_sweocr/ocr_corrector.py
+++ b/ocr-correction-viklofg-sweocr/src/sbx_ocr_correction_viklofg_sweocr/ocr_corrector.py
@@ -3,6 +3,7 @@ import re
 
 from typing import Any, Optional
 
+import torch
 from parallel_corpus import graph
 from parallel_corpus.text_token import Token
 from sparv import api as sparv_api  # type: ignore [import-untyped]
@@ -37,13 +38,30 @@ def __init__(self, *, tokenizer: Any, model: Any) -> None:
         """Construct an OcrCorrector."""
         self.tokenizer = tokenizer
         self.model = model
-        self.pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
+        if torch.cuda.is_available():
+            logger.info("Using GPU (cuda)")
+            dtype = torch.float16
+        else:
+            logger.warning("Using CPU, is cuda available?")
+            dtype = torch.float32
+        device_map = "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None
+        self.pipeline = pipeline(
+            "text2text-generation", model=model, tokenizer=tokenizer, device_map=device_map, torch_dtype=dtype
+        )
 
     @classmethod
     def default(cls) -> "OcrCorrector":
         """Create a default OcrCorrector."""
+        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
         tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, revision=TOKENIZER_REVISION)
-        model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, revision=MODEL_REVISION)
+        model = T5ForConditionalGeneration.from_pretrained(
+            MODEL_NAME,
+            revision=MODEL_REVISION,
+            torch_dtype=dtype,
+            device_map=("auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None),
+        )
+        if torch.cuda.is_available() and torch.cuda.device_count() == 1:
+            model = model.cuda()  # type: ignore
         return cls(model=model, tokenizer=tokenizer)
 
     def calculate_corrections(self, text: list[str]) -> list[tuple[tuple[int, int], Optional[str]]]: