From a2e4e4bede177fb162567d6b2ae7ad5559ecd150 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Tue, 23 Apr 2024 16:43:34 +0530
Subject: [PATCH] Add support for Llama 3 in Khoj offline mode

- Improve extract-question prompts to explicitly request a JSON list
- Use the llama-3 chat format if the HF repo_id mentions llama-3, as the
  llama-cpp-python logic for detecting when to use the llama-3 chat format
  isn't robust enough yet
---
 pyproject.toml                                   | 2 +-
 src/khoj/processor/conversation/offline/utils.py | 7 ++++++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 12ef261ba..26d08c259 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -64,7 +64,7 @@ dependencies = [
     "pymupdf >= 1.23.5",
     "django == 4.2.10",
     "authlib == 1.2.1",
-    "llama-cpp-python == 0.2.56",
+    "llama-cpp-python == 0.2.64",
     "itsdangerous == 2.1.2",
     "httpx == 0.25.0",
     "pgvector == 0.2.4",
diff --git a/src/khoj/processor/conversation/offline/utils.py b/src/khoj/processor/conversation/offline/utils.py
index c43c73538..05de4b9f7 100644
--- a/src/khoj/processor/conversation/offline/utils.py
+++ b/src/khoj/processor/conversation/offline/utils.py
@@ -2,6 +2,7 @@
 import logging
 import math
 import os
+from typing import Any, Dict
 
 from huggingface_hub.constants import HF_HUB_CACHE
 
@@ -14,12 +15,16 @@
 def download_model(repo_id: str, filename: str = "*Q4_K_M.gguf", max_tokens: int = None):
     # Initialize Model Parameters
     # Use n_ctx=0 to get context size from the model
-    kwargs = {"n_threads": 4, "n_ctx": 0, "verbose": False}
+    kwargs: Dict[str, Any] = {"n_threads": 4, "n_ctx": 0, "verbose": False}
 
     # Decide whether to load model to GPU or CPU
     device = "gpu" if state.chat_on_gpu and state.device != "cpu" else "cpu"
     kwargs["n_gpu_layers"] = -1 if device == "gpu" else 0
 
+    # Add chat format if known
+    if "llama-3" in repo_id.lower():
+        kwargs["chat_format"] = "llama-3"
+
     # Check if the model is already downloaded
     model_path = load_model_from_cache(repo_id, filename)
     chat_model = None
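
Note (not part of the patch): a minimal sketch of how the kwargs assembled in download_model() end up being consumed by llama-cpp-python, showing where the chat_format override takes effect. The repo id, filename, and the use of Llama.from_pretrained below are illustrative assumptions for this example, not code from this commit.

    # Sketch: load a Llama 3 GGUF model with the chat format forced explicitly,
    # rather than relying on llama-cpp-python's own chat-format auto-detection.
    from llama_cpp import Llama

    repo_id = "bartowski/Meta-Llama-3-8B-Instruct-GGUF"  # hypothetical Llama 3 repo
    kwargs = {"n_threads": 4, "n_ctx": 0, "verbose": False, "n_gpu_layers": -1}
    if "llama-3" in repo_id.lower():
        # Mirror the patch: pin the llama-3 chat template when the repo id mentions llama-3
        kwargs["chat_format"] = "llama-3"

    chat_model = Llama.from_pretrained(repo_id=repo_id, filename="*Q4_K_M.gguf", **kwargs)
    response = chat_model.create_chat_completion(messages=[{"role": "user", "content": "Hello"}])

With chat_format pinned this way, responses are formatted with Llama 3's chat template even on llama-cpp-python versions whose template detection misidentifies the model.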