LLM reranker
OskarLiew committed Oct 7, 2024
1 parent d9d4c16 commit 5612fdd
Showing 3 changed files with 103 additions and 9 deletions.
17 changes: 12 additions & 5 deletions src/app.py
@@ -5,7 +5,7 @@
 from config import Config, get_log_config
 from fastapi import Depends, FastAPI, HTTPException, Request, status
 from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
-from inference import SimilarityClassifierModel
+from inference import SimilarityClassifierLLM, SimilarityClassifierModel
 from models import PredictRequestModel, PredictResponseModel


@@ -14,10 +14,17 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
"""Setup and teardown events of the app"""
# Setup
config = Config()
app.state.model = SimilarityClassifierModel(
model_name=config.reranker_model_name,
trust_remote_code=config.trust_remote_code,
)
if config.reranker_type == "CrossEncoder":
SimilarityClassifier = SimilarityClassifierModel
elif config.reranker_type == "LLM":
SimilarityClassifier = SimilarityClassifierLLM
else:
raise ValueError(f"No reranker type '{config.reranker_type}'")

app.state.model = SimilarityClassifier(model_name=config.reranker_model_name)
app.state.api_key = config.api_key
yield
# Teardown
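The selected class is instantiated once at startup and stashed on app.state, so request handlers can score pairs without reloading model weights. A minimal sketch of how a handler might consume it (the route shape and the request/response field names are assumptions for illustration, not part of this commit):

from fastapi import Request

from models import PredictRequestModel, PredictResponseModel


# Hypothetical handler: assumes PredictRequestModel carries "query" and
# "passages" fields and PredictResponseModel wraps a list of scores
async def predict(request: Request, body: PredictRequestModel) -> PredictResponseModel:
    pairs = [(body.query, passage) for passage in body.passages]
    scores = request.app.state.model.predict(pairs)
    return PredictResponseModel(scores=list(scores))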
2 changes: 2 additions & 0 deletions src/config.py
@@ -1,4 +1,5 @@
 from pathlib import Path
+from typing import Literal
 
 from pydantic_settings import BaseSettings

@@ -7,6 +8,7 @@ class Config(BaseSettings):
     reranker_model_name: str = "BAAI/bge-reranker-v2-m3"
     api_key: str | None = None
     trust_remote_code: bool = False
+    reranker_type: Literal["CrossEncoder", "LLM"] = "CrossEncoder"
 
 
 def get_log_config(level: str = "INFO"):
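Because Config extends pydantic-settings' BaseSettings, the new field can be driven by an environment variable with no further code changes; by default pydantic-settings matches variables to field names case-insensitively. A minimal sketch of switching the service to the LLM reranker, assuming no custom env prefix is configured (the checkpoint name is illustrative, not part of this commit):

import os

from config import Config

os.environ["RERANKER_TYPE"] = "LLM"
# Illustrative checkpoint: a causal-LM reranker such as BAAI/bge-reranker-v2-gemma
os.environ["RERANKER_MODEL_NAME"] = "BAAI/bge-reranker-v2-gemma"

config = Config()
assert config.reranker_type == "LLM"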
93 changes: 89 additions & 4 deletions src/inference.py
@@ -1,14 +1,100 @@
 import logging
 
 import torch
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers import (
+    AutoModelForCausalLM,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+)
 
 CACHE_DIR = "/app/hf_cache"
 
 
 logger = logging.getLogger(__name__)
 
 
+class SimilarityClassifierLLM:
+    def __init__(self, model_name: str, trust_remote_code: bool = False) -> None:
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            cache_dir=CACHE_DIR,
+            trust_remote_code=trust_remote_code,
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            cache_dir=CACHE_DIR,
+            trust_remote_code=trust_remote_code,
+            device_map="auto",
+        )
+
+        # Token id of "Yes": its final-position logit is used as the relevance score
+        self.yes_loc = self.tokenizer("Yes", add_special_tokens=False)["input_ids"][0]
+        self.model.eval()
+
+        logger.info(f"Loaded model {model_name} on device {self.model.device}")
+
+    def get_inputs(self, pairs, max_length=1024):
+        prompt = "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'."
+        sep = "\n"
+        prompt_inputs = self.tokenizer(
+            prompt, return_tensors=None, add_special_tokens=False
+        )["input_ids"]
+        sep_inputs = self.tokenizer(sep, return_tensors=None, add_special_tokens=False)[
+            "input_ids"
+        ]
+        inputs = []
+        for query, passage in pairs:
+            # Reserve roughly 3/4 of the token budget for the query
+            query_inputs = self.tokenizer(
+                f"A: {query}",
+                return_tensors=None,
+                add_special_tokens=False,
+                max_length=max_length * 3 // 4,
+                truncation=True,
+            )
+            passage_inputs = self.tokenizer(
+                f"B: {passage}",
+                return_tensors=None,
+                add_special_tokens=False,
+                max_length=max_length,
+                truncation=True,
+            )
+            # Join query and passage, truncating only the passage on overflow
+            item = self.tokenizer.prepare_for_model(
+                [self.tokenizer.bos_token_id] + query_inputs["input_ids"],
+                sep_inputs + passage_inputs["input_ids"],
+                truncation="only_second",
+                max_length=max_length,
+                padding=False,
+                return_attention_mask=False,
+                return_token_type_ids=False,
+                add_special_tokens=False,
+            )
+            # Append the 'Yes'/'No' instruction prompt after each pair
+            item["input_ids"] = item["input_ids"] + sep_inputs + prompt_inputs
+            item["attention_mask"] = [1] * len(item["input_ids"])
+            inputs.append(item)
+        return self.tokenizer.pad(
+            inputs,
+            padding=True,
+            max_length=max_length + len(sep_inputs) + len(prompt_inputs),
+            pad_to_multiple_of=8,
+            return_tensors="pt",
+        )
+
+    def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
+        with torch.no_grad():
+            inputs = self.get_inputs(pairs).to(self.model.device)
+            # Relevance score = logit of the "Yes" token at the last position
+            scores = (
+                self.model(**inputs, return_dict=True)
+                .logits[:, -1, self.yes_loc]
+                .view(-1)
+                .float()
+            )
+        return scores
+
+
 class SimilarityClassifierModel:
     def __init__(self, model_name: str, trust_remote_code: bool = False) -> None:
         self.tokenizer = AutoTokenizer.from_pretrained(
@@ -20,10 +106,9 @@ def __init__(self, model_name: str, trust_remote_code: bool = False) -> None:
             model_name,
             cache_dir=CACHE_DIR,
             trust_remote_code=trust_remote_code,
+            device_map="auto",
         )
 
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model.to(self.device)
         self.model.eval()
 
         logger.info(f"Loaded model {model_name} on device {self.model.device}")
@@ -36,6 +121,6 @@ def predict(self, pairs: list[tuple[str, str]]) -> list[float]:
             truncation=True,
             return_tensors="pt",
             max_length=1024,
-        ).to(self.device)
+        ).to(self.model.device)
         scores = self.model(**inputs, return_dict=True).logits.view(-1).float()
         return scores

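A short usage sketch of the new class, assuming an LLM reranker checkpoint in the style of BAAI/bge-reranker-v2-gemma (the name is illustrative; any causal LM trained to answer the 'Yes'/'No' relevance prompt is scored the same way):

from inference import SimilarityClassifierLLM

reranker = SimilarityClassifierLLM("BAAI/bge-reranker-v2-gemma")
scores = reranker.predict(
    [
        ("what is a panda?", "The giant panda is a bear species endemic to China."),
        ("what is a panda?", "FastAPI is a web framework for building APIs."),
    ]
)
# A higher score means a larger "Yes" logit, i.e. a more relevant passage
print(scores)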