Skip to content

Commit

Permalink
Merge pull request #54 from spraakbanken/use-gpu
Browse files Browse the repository at this point in the history
feat: use gpu if available
  • Loading branch information
kod-kristoff authored Nov 22, 2024
2 parents b4f7ba1 + cb24252 commit aa40ac4
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 76 deletions.
25 changes: 0 additions & 25 deletions .github/workflows/check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,7 @@ jobs:
version: ${{ env.UV_VERSION }}
enable-cache: true

- name: Load cached venv
id: cached-venv
uses: actions/cache@v4
with:
path: .venv
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/pyproject.toml') }}-${{ hashFiles('**/uv.lock') }}

- name: Install dependencies
if: steps.cached-venv.outputs.cache-hit != 'true'
run: make install-dev

- name: check formatting
Expand Down Expand Up @@ -80,15 +72,7 @@ jobs:
version: ${{ env.UV_VERSION }}
enable-cache: true

- name: Load cached venv
id: cached-venv
uses: actions/cache@v4
with:
path: .venv
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/pyproject.toml') }}-${{ hashFiles('**/uv.lock') }}

- name: Install dependencies
if: steps.cached-venv.outputs.cache-hit != 'true'
run: make install-dev

- name: lint code
Expand Down Expand Up @@ -116,15 +100,7 @@ jobs:
version: ${{ env.UV_VERSION }}
enable-cache: true

- name: Load cached venv
id: cached-venv
uses: actions/cache@v4
with:
path: .venv
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/pyproject.toml') }}-${{ hashFiles('**/uv.lock') }}

- name: Install dependencies
if: steps.cached-venv.outputs.cache-hit != 'true'
run: make install-dev

- name: type-check code
Expand All @@ -144,4 +120,3 @@ jobs:
uses: re-actors/alls-green@release/v1
with:
jobs: ${{ toJSON(needs) }}
allowed-failures: upload-coverage
34 changes: 1 addition & 33 deletions .github/workflows/scheduled.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,38 +21,6 @@ env:
UV_VERSION: "0.5.1"

jobs:
# https://twitter.com/mycoliza/status/1571295690063753218
nightly:
runs-on: ubuntu-latest
name: ubuntu / 3.14-dev
steps:
- uses: actions/checkout@v4
with:
submodules: true

- name: Set up uv
uses: astral-sh/setup-uv@v3
with:
version: ${{ env.UV_VERSION }}
enable-cache: true

- name: Install python
uses: actions/setup-python@v5
with:
python-version: "3.14-dev"

- run: python --version

- name: uv lock
if: hashFiles('uv.lock') == ''
run: uv lock

- name: uv sync --dev
run: uv sync --dev

- name: make test
run: make test

# https://twitter.com/alcuadrado/status/1571291687837732873
update:
# This action checks that updating the dependencies of this crate to the latest available that
Expand Down Expand Up @@ -85,7 +53,7 @@ jobs:

- name: uv sync --dev --upgrade
if: hashFiles('uv.lock') != ''
run: uv sync --dev --upgrade
run: uv sync --dev --upgrade --all-packages --all-extras

- name: make test
if: hashFiles('uv.lock') != ''
Expand Down
20 changes: 7 additions & 13 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,26 +48,20 @@ jobs:
with:
submodules: true

- name: Set up uv
uses: astral-sh/setup-uv@v3
with:
version: ${{ env.UV_VERSION }}
enable-cache: true

- name: Set up Python ${{ matrix.python-version }}
id: setup-python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Set up uv
run: curl -LsSf https://astral.sh/uv/${{ env.UV_VERSION }}/install.sh | sh

- name: Restore uv cache
uses: actions/cache@v4
with:
path: ${{ env.UV_CACHE_DIR }}
key: uv-${{ runner.os }}-${{ hashFiles('uv.lock') }}
restore-keys: |
uv-${{ runner.os }}-${{ hashFiles('uv.lock') }}
uv-${{ runner.os }}
- name: Install dependencies
run: uv sync --all-extras --all-packages --dev
run: make install-dev

- name: Run tests for coverage
run: make test-w-coverage cov_report=xml
Expand Down
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ dev: install-dev

# setup development environment
install-dev:
uv sync --all-packages --group dev --verbose
uv sync --all-packages --dev

# setup production environment
install:
uv sync --no-dev
uv sync --all-packages --no-dev

lock: uv.lock

Expand Down Expand Up @@ -152,7 +152,7 @@ prepare-release: update-changelog tests/requirements-testing.lock

# we use lock extension so that dependabot doesn't pick up changes in this file
tests/requirements-testing.lock: pyproject.toml
uv export --dev --format requirements-txt --no-hashes --output-file $@
uv export --dev --format requirements-txt --no-hashes --no-emit-project --output-file $@

.PHONY: update-changelog
update-changelog: CHANGELOG.md
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
from typing import Any, Optional

import torch
from parallel_corpus import graph
from parallel_corpus.text_token import Token
from sparv import api as sparv_api # type: ignore [import-untyped]
Expand Down Expand Up @@ -37,13 +38,30 @@ def __init__(self, *, tokenizer: Any, model: Any) -> None:
"""Construct an OcrCorrector."""
self.tokenizer = tokenizer
self.model = model
self.pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
if torch.cuda.is_available():
logger.info("Using GPU (cuda)")
dtype = torch.float16
else:
logger.warning("Using CPU, is cuda available?")
dtype = torch.float32
device_map = "auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None
self.pipeline = pipeline(
"text2text-generation", model=model, tokenizer=tokenizer, device_map=device_map, torch_dtype=dtype
)

@classmethod
def default(cls) -> "OcrCorrector":
"""Create a default OcrCorrector."""
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, revision=TOKENIZER_REVISION)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, revision=MODEL_REVISION)
model = T5ForConditionalGeneration.from_pretrained(
MODEL_NAME,
revision=MODEL_REVISION,
torch_dtype=dtype,
device_map=("auto" if torch.cuda.is_available() and torch.cuda.device_count() > 1 else None),
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda() # type: ignore
return cls(model=model, tokenizer=tokenizer)

def calculate_corrections(self, text: list[str]) -> list[tuple[tuple[int, int], Optional[str]]]:
Expand Down

0 comments on commit aa40ac4

Please sign in to comment.