Skip to content

Commit

Permalink
Add option to trust remote code
Browse files Browse the repository at this point in the history
  • Loading branch information
OskarLiew committed Oct 7, 2024
1 parent ad6d868 commit dfee6ae
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
5 changes: 4 additions & 1 deletion src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ async def lifespan(app: FastAPI) -> AsyncIterator[None]:
"""Setup and teardown events of the app"""
# Setup
config = Config()
app.state.model = SimilarityClassifierModel(model_name=config.reranker_model_name)
app.state.model = SimilarityClassifierModel(
model_name=config.reranker_model_name,
trust_remote_code=config.trust_remote_code,
)
app.state.api_key = config.api_key
yield
# Teardown
Expand Down
1 change: 1 addition & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
class Config(BaseSettings):
reranker_model_name: str = "BAAI/bge-reranker-v2-m3"
api_key: str | None = None
trust_remote_code: bool = False


def get_log_config(level: str = "INFO"):
Expand Down
12 changes: 9 additions & 3 deletions src/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,16 @@


class SimilarityClassifierModel:
def __init__(self, model_name: str) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=CACHE_DIR)
def __init__(self, model_name: str, trust_remote_code: bool = False) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
cache_dir=CACHE_DIR,
trust_remote_code=trust_remote_code,
)
self.model = AutoModelForSequenceClassification.from_pretrained(
model_name, cache_dir=CACHE_DIR
model_name,
cache_dir=CACHE_DIR,
trust_remote_code=trust_remote_code,
)

self.device = "cuda" if torch.cuda.is_available() else "cpu"
Expand Down

0 comments on commit dfee6ae

Please sign in to comment.