From dfee6ae03ca5b5cc0eeefb195779736aea0f7373 Mon Sep 17 00:00:00 2001 From: Oskar Liew Date: Mon, 7 Oct 2024 08:49:16 +0200 Subject: [PATCH] Add option to trust remote code --- src/app.py | 5 ++++- src/config.py | 1 + src/inference.py | 12 +++++++++--- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/app.py b/src/app.py index c02c38a..79b563b 100644 --- a/src/app.py +++ b/src/app.py @@ -14,7 +14,10 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]: """Setup and teardown events of the app""" # Setup config = Config() - app.state.model = SimilarityClassifierModel(model_name=config.reranker_model_name) + app.state.model = SimilarityClassifierModel( + model_name=config.reranker_model_name, + trust_remote_code=config.trust_remote_code, + ) app.state.api_key = config.api_key yield # Teardown diff --git a/src/config.py b/src/config.py index a1579aa..1cd9275 100644 --- a/src/config.py +++ b/src/config.py @@ -6,6 +6,7 @@ class Config(BaseSettings): reranker_model_name: str = "BAAI/bge-reranker-v2-m3" api_key: str | None = None + trust_remote_code: bool = False def get_log_config(level: str = "INFO"): diff --git a/src/inference.py b/src/inference.py index 3eac554..84383b1 100644 --- a/src/inference.py +++ b/src/inference.py @@ -10,10 +10,16 @@ class SimilarityClassifierModel: - def __init__(self, model_name: str) -> None: - self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR) + def __init__(self, model_name: str, trust_remote_code: bool = False) -> None: + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + cache_dir=CACHE_DIR, + trust_remote_code=trust_remote_code, + ) self.model = AutoModelForSequenceClassification.from_pretrained( - model_name, cache_dir=CACHE_DIR + model_name, + cache_dir=CACHE_DIR, + trust_remote_code=trust_remote_code, ) self.device = "cuda" if torch.cuda.is_available() else "cpu"