Skip to content

Commit

Permalink
Cast to device
Browse files Browse the repository at this point in the history
  • Loading branch information
OskarLiew committed Oct 4, 2024
1 parent 318ee0d commit ad6d868
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/inference.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
import logging

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

CACHE_DIR = "/app/hf_cache"


logger = logging.getLogger(__name__)


class SimilarityClassifierModel:
def __init__(self, model_name: str) -> None:
    """Load the tokenizer and sequence classifier, then prepare the model for inference.

    Both pretrained artifacts are requested with ``cache_dir=CACHE_DIR``.
    The model is moved to ``"cuda"`` when ``torch.cuda.is_available()``
    reports true and to ``"cpu"`` otherwise, and is then switched to
    evaluation mode via ``eval()``.

    Args:
        model_name: Identifier handed unchanged to both ``from_pretrained``
            calls.
    """
    pretrained_kwargs = {"cache_dir": CACHE_DIR}
    self.tokenizer = AutoTokenizer.from_pretrained(model_name, **pretrained_kwargs)
    self.model = AutoModelForSequenceClassification.from_pretrained(
        model_name, **pretrained_kwargs
    )

    # Remember the chosen device as a plain string so later calls can
    # move their inputs to the same place as the model.
    if torch.cuda.is_available():
        self.device = "cuda"
    else:
        self.device = "cpu"

    self.model.to(self.device)
    self.model.eval()

    # Report the device the model itself says it is on, not just the request.
    logger.info(f"Loaded model {model_name} on device {self.model.device}")

def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
with torch.no_grad():
inputs = self.tokenizer(
Expand All @@ -20,6 +30,6 @@ def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
truncation=True,
return_tensors="pt",
max_length=1024,
)
).to(self.device)
scores = self.model(**inputs, return_dict=True).logits.view(-1).float()
return scores

0 comments on commit ad6d868

Please sign in to comment.