From 16b475c7e89230a85e76c74d94d4216e1b0f6b14 Mon Sep 17 00:00:00 2001 From: Chih-Yu Yeh Date: Thu, 27 Jun 2024 11:17:28 +0800 Subject: [PATCH] fix wren-ai-service 0.6.0 issues (#445) * fix typo * fix missing env and recreate index by default * fix tests * clean document store when init * fix test * set QDRANT_HOST default value * logging apibase * read env files * fix loading .env.ai issues * change default ollama model to llama3:70b * remove QDRANT_HOST * make .env.ai optional * refine wording * rename .env.ai to .env.ai.old to old files * add comment --- docker/.env.example | 2 ++ docker/README.md | 8 ++----- docker/docker-compose-dev.yaml | 3 +++ docker/docker-compose.yaml | 3 +++ wren-ai-service/src/globals.py | 7 ++++++ .../src/providers/document_store/qdrant.py | 5 ++--- wren-ai-service/src/providers/llm/ollama.py | 2 +- wren-ai-service/src/providers/llm/openai.py | 1 + wren-ai-service/src/utils.py | 9 ++++---- .../tests/pytest/pipelines/test_ask.py | 2 +- wren-launcher/commands/launch.go | 22 +++++++++++++++++++ wren-launcher/utils/docker.go | 11 +++++++--- 12 files changed, 57 insertions(+), 18 deletions(-) diff --git a/docker/.env.example b/docker/.env.example index 9f8557421..36b2d65d6 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -1,6 +1,8 @@ COMPOSE_PROJECT_NAME=wren PLATFORM=linux/amd64 +PROJECT_DIR=. + # service port WREN_ENGINE_PORT=8080 WREN_ENGINE_SQL_PORT=7432 diff --git a/docker/README.md b/docker/README.md index 8011d5d62..3994e87ac 100644 --- a/docker/README.md +++ b/docker/README.md @@ -22,9 +22,5 @@ Path structure as following: 1. copy `.env.example` to `.env.local` and modify the OpenAI API key. 2. (optional) copy `.env.ai.example` to `.env.ai` and fill in necessary information if you would like to use custom LLM. 3. (optional) if your port 3000 is occupied, you can modify the `HOST_PORT` in `.env.example`. -4. 
start all services: - - using OpenAI: `docker-compose --env-file .env.local up -d` - - using custom LLM: `docker-compose --env-file .env.local --env-file .env.ai up -d` -5. stop all services: - - using OpenAI: `docker-compose --env-file .env.local down` - - using custom LLM: `docker-compose --env-file .env.local --env-file .env.ai down` \ No newline at end of file +4. start all services: `docker-compose --env-file .env.local up -d` +5. stop all services: `docker-compose --env-file .env.local down` \ No newline at end of file diff --git a/docker/docker-compose-dev.yaml b/docker/docker-compose-dev.yaml index 8c26b9dfa..1be103243 100644 --- a/docker/docker-compose-dev.yaml +++ b/docker/docker-compose-dev.yaml @@ -50,6 +50,9 @@ services: # sometimes the console won't show print messages, # using PYTHONUNBUFFERED: 1 can fix this PYTHONUNBUFFERED: 1 + env_file: + - path: ${PROJECT_DIR}/.env.ai + required: false networks: - wren depends_on: diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index b930216bf..6cbbf70d6 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -63,6 +63,9 @@ services: # sometimes the console won't show print messages, # using PYTHONUNBUFFERED: 1 can fix this PYTHONUNBUFFERED: 1 + env_file: + - path: ${PROJECT_DIR}/.env.ai + required: false networks: - wren depends_on: diff --git a/wren-ai-service/src/globals.py b/wren-ai-service/src/globals.py index 95a18a0e6..2b070ecec 100644 --- a/wren-ai-service/src/globals.py +++ b/wren-ai-service/src/globals.py @@ -38,6 +38,13 @@ def init_globals(): llm_provider, document_store_provider, engine = init_providers() + # Recreate the document store to ensure a clean slate + # TODO: for SaaS, we need to use a flag to prevent this collection_recreation + document_store_provider.get_store(recreate_index=True) + document_store_provider.get_store( + dataset_name="view_questions", recreate_index=True + ) + SEMANTIC_SERVICE = SemanticsService( pipelines={ "generate_description": 
description.Generation( diff --git a/wren-ai-service/src/providers/document_store/qdrant.py b/wren-ai-service/src/providers/document_store/qdrant.py index 33b5c7dec..d94f928f1 100644 --- a/wren-ai-service/src/providers/document_store/qdrant.py +++ b/wren-ai-service/src/providers/document_store/qdrant.py @@ -193,7 +193,7 @@ async def run( @provider("qdrant") class QdrantProvider(DocumentStoreProvider): - def __init__(self, location: str = os.getenv("QDRANT_HOST")): + def __init__(self, location: str = os.getenv("QDRANT_HOST", "qdrant")): self._location = location def get_store( @@ -203,7 +203,7 @@ def get_store( if os.getenv("EMBEDDING_MODEL_DIMENSION") else 0 ) - or get_default_embedding_model_dim(os.getenv("LLM_PROVIDER", "opeani")), + or get_default_embedding_model_dim(os.getenv("LLM_PROVIDER", "openai")), dataset_name: Optional[str] = None, recreate_index: bool = False, ): @@ -212,7 +212,6 @@ def get_store( embedding_dim=embedding_model_dim, index=dataset_name or "Document", recreate_index=recreate_index, - # hnsw_config={"ef_construct": 200, "m": 32}, # https://qdrant.tech/documentation/concepts/indexing/#vector-index ) def get_retriever( diff --git a/wren-ai-service/src/providers/llm/ollama.py b/wren-ai-service/src/providers/llm/ollama.py index d06544339..42418d245 100644 --- a/wren-ai-service/src/providers/llm/ollama.py +++ b/wren-ai-service/src/providers/llm/ollama.py @@ -19,7 +19,7 @@ logger = logging.getLogger("wren-ai-service") OLLAMA_URL = "http://localhost:11434" -GENERATION_MODEL_NAME = "llama3:8b" +GENERATION_MODEL_NAME = "llama3:70b" GENERATION_MODEL_KWARGS = { "temperature": 0, } diff --git a/wren-ai-service/src/providers/llm/openai.py b/wren-ai-service/src/providers/llm/openai.py index ca9701f6d..163b01652 100644 --- a/wren-ai-service/src/providers/llm/openai.py +++ b/wren-ai-service/src/providers/llm/openai.py @@ -291,6 +291,7 @@ def _verify_api_key(api_key: str, api_base: str) -> None: """ OpenAI(api_key=api_key, base_url=api_base).models.list() 
+ logger.info(f"Initializing OpenAILLM provider with API base: {api_base}") # TODO: currently only OpenAI api key can be verified if api_base == OPENAI_API_BASE: _verify_api_key(api_key.resolve_value(), api_base) diff --git a/wren-ai-service/src/utils.py b/wren-ai-service/src/utils.py index b26d64c80..240402070 100644 --- a/wren-ai-service/src/utils.py +++ b/wren-ai-service/src/utils.py @@ -63,12 +63,13 @@ def init_providers() -> Tuple[LLMProvider, DocumentStoreProvider, Engine]: logger.info("Initializing providers...") loader.import_mods() - llm_provider = loader.get_provider(os.getenv("LLM_PROVIDER", "openai")) + llm_provider = loader.get_provider(os.getenv("LLM_PROVIDER", "openai"))() document_store_provider = loader.get_provider( os.getenv("DOCUMENT_STORE_PROVIDER", "qdrant") - ) - engine = loader.get_provider(os.getenv("ENGINE", "wren-ui")) - return llm_provider(), document_store_provider(), engine() + )() + engine = loader.get_provider(os.getenv("ENGINE", "wren-ui"))() + + return llm_provider, document_store_provider, engine def timer(func): diff --git a/wren-ai-service/tests/pytest/pipelines/test_ask.py b/wren-ai-service/tests/pytest/pipelines/test_ask.py index a10133fcc..812332ba7 100644 --- a/wren-ai-service/tests/pytest/pipelines/test_ask.py +++ b/wren-ai-service/tests/pytest/pipelines/test_ask.py @@ -92,7 +92,7 @@ def test_indexing_pipeline( assert document_store_provider.get_store().count_documents() == 3 assert ( document_store_provider.get_store( - dataset_name="view_questions" + dataset_name="view_questions", ).count_documents() == 1 ) diff --git a/wren-launcher/commands/launch.go b/wren-launcher/commands/launch.go index b77f8d6a7..5e82177c5 100644 --- a/wren-launcher/commands/launch.go +++ b/wren-launcher/commands/launch.go @@ -125,6 +125,23 @@ func isEnvFileValidForCustomLLM(projectDir string) error { return nil } +func renameEnvFileForCustomLLM(projectDir string) error { + // rename .env.ai file to .env.ai.old + envFilePath := 
path.Join(projectDir, ".env.ai")
+
+	if _, err := os.Stat(envFilePath); err == nil {
+		newEnvFilePath := path.Join(projectDir, ".env.ai.old")
+		renameErr := os.Rename(envFilePath, newEnvFilePath)
+		if renameErr != nil {
+			return renameErr
+		}
+
+		return nil
+	}
+
+	return nil
+}
+
 func Launch() {
 
 	// recover from panic
@@ -151,6 +168,11 @@ func Launch() {
 	openaiApiKey := ""
 	openaiGenerationModel := ""
 	if llmProvider == "OpenAI" {
+		err = renameEnvFileForCustomLLM(projectDir)
+		if err != nil {
+			panic(err)
+		}
+
 		// ask for OpenAI API key
 		pterm.Print("\n")
 		openaiApiKey, _ = askForAPIKey()
diff --git a/wren-launcher/utils/docker.go b/wren-launcher/utils/docker.go
index 9e4b9955e..4d0bc9240 100644
--- a/wren-launcher/utils/docker.go
+++ b/wren-launcher/utils/docker.go
@@ -32,10 +32,14 @@ const (
 	PG_USERNAME string = "wren-user"
 )
 
-func replaceEnvFileContent(content string, openaiApiKey string, openAIGenerationModel string, hostPort int, aiPort int, pg_password string, userUUID string, telemetryEnabled bool) string {
+func replaceEnvFileContent(content string, projectDir string, openaiApiKey string, openAIGenerationModel string, hostPort int, aiPort int, pg_password string, userUUID string, telemetryEnabled bool) string {
+	// replace PROJECT_DIR
+	reg := regexp.MustCompile(`PROJECT_DIR=(.*)`)
+	str := reg.ReplaceAllString(content, "PROJECT_DIR="+projectDir)
+
 	// replace OPENAI_API_KEY
-	reg := regexp.MustCompile(`OPENAI_API_KEY=sk-(.*)`)
-	str := reg.ReplaceAllString(content, "OPENAI_API_KEY="+openaiApiKey)
+	reg = regexp.MustCompile(`OPENAI_API_KEY=sk-(.*)`)
+	str = reg.ReplaceAllString(content, "OPENAI_API_KEY="+openaiApiKey)
 
 	// replace GENERATION_MODEL
 	reg = regexp.MustCompile(`GENERATION_MODEL=(.*)`)
@@ -189,6 +193,7 @@ func PrepareDockerFiles(openaiApiKey string, openaiGenerationModel string, hostP
 	// replace the content with regex
 	envFileContent := replaceEnvFileContent(
 		string(envExampleFileContent),
+		projectDir,
 		openaiApiKey,
 		openaiGenerationModel,
 		hostPort,