Skip to content

Commit

Permalink
fix wren-ai-service 0.6.0 issues (#445)
Browse files Browse the repository at this point in the history
* fix typo

* fix missing env and recreate index by default

* fix tests

* clean document store when init

* fix test

* set QDRANT_HOST default value

* logging apibase

* read env files

* fix loading .env.ai issues

* change default ollama model to llama3:70b

* remove QDRANT_HOST

* make .env.ai optional

* refine wording

* rename .env.ai to .env.ai.old to old files

* add comment
  • Loading branch information
cyyeh authored Jun 27, 2024
1 parent 937f2ab commit 16b475c
Show file tree
Hide file tree
Showing 12 changed files with 57 additions and 18 deletions.
2 changes: 2 additions & 0 deletions docker/.env.example
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
COMPOSE_PROJECT_NAME=wren
PLATFORM=linux/amd64

PROJECT_DIR=.

# service port
WREN_ENGINE_PORT=8080
WREN_ENGINE_SQL_PORT=7432
Expand Down
8 changes: 2 additions & 6 deletions docker/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,5 @@ Path structure as following:
1. copy `.env.example` to `.env.local` and modify the OpenAI API key.
2. (optional) copy `.env.ai.example` to `.env.ai` and fill in necessary information if you would like to use custom LLM.
3. (optional) if your port 3000 is occupied, you can modify the `HOST_PORT` in `.env.example`.
4. start all services:
- using OpenAI: `docker-compose --env-file .env.local up -d`
- using custom LLM: `docker-compose --env-file .env.local --env-file .env.ai up -d`
5. stop all services:
- using OpenAI: `docker-compose --env-file .env.local down`
- using custom LLM: `docker-compose --env-file .env.local --env-file .env.ai down`
4. start all services: `docker-compose --env-file .env.local up -d`
5. stop all services: `docker-compose --env-file .env.local down`
3 changes: 3 additions & 0 deletions docker/docker-compose-dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ services:
# sometimes the console won't show print messages,
# using PYTHONUNBUFFERED: 1 can fix this
PYTHONUNBUFFERED: 1
env_file:
- path: ${PROJECT_DIR}/.env.ai
required: false
networks:
- wren
depends_on:
Expand Down
3 changes: 3 additions & 0 deletions docker/docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ services:
# sometimes the console won't show print messages,
# using PYTHONUNBUFFERED: 1 can fix this
PYTHONUNBUFFERED: 1
env_file:
- path: ${PROJECT_DIR}/.env.ai
required: false
networks:
- wren
depends_on:
Expand Down
7 changes: 7 additions & 0 deletions wren-ai-service/src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@ def init_globals():

llm_provider, document_store_provider, engine = init_providers()

# Recreate the document store to ensure a clean slate
# TODO: for SaaS, we need to use a flag to prevent this collection_recreation
document_store_provider.get_store(recreate_index=True)
document_store_provider.get_store(
dataset_name="view_questions", recreate_index=True
)

SEMANTIC_SERVICE = SemanticsService(
pipelines={
"generate_description": description.Generation(
Expand Down
5 changes: 2 additions & 3 deletions wren-ai-service/src/providers/document_store/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ async def run(

@provider("qdrant")
class QdrantProvider(DocumentStoreProvider):
def __init__(self, location: str = os.getenv("QDRANT_HOST")):
def __init__(self, location: str = os.getenv("QDRANT_HOST", "qdrant")):
self._location = location

def get_store(
Expand All @@ -203,7 +203,7 @@ def get_store(
if os.getenv("EMBEDDING_MODEL_DIMENSION")
else 0
)
or get_default_embedding_model_dim(os.getenv("LLM_PROVIDER", "opeani")),
or get_default_embedding_model_dim(os.getenv("LLM_PROVIDER", "openai")),
dataset_name: Optional[str] = None,
recreate_index: bool = False,
):
Expand All @@ -212,7 +212,6 @@ def get_store(
embedding_dim=embedding_model_dim,
index=dataset_name or "Document",
recreate_index=recreate_index,
# hnsw_config={"ef_construct": 200, "m": 32}, # https://qdrant.tech/documentation/concepts/indexing/#vector-index
)

def get_retriever(
Expand Down
2 changes: 1 addition & 1 deletion wren-ai-service/src/providers/llm/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
logger = logging.getLogger("wren-ai-service")

OLLAMA_URL = "http://localhost:11434"
GENERATION_MODEL_NAME = "llama3:8b"
GENERATION_MODEL_NAME = "llama3:70b"
GENERATION_MODEL_KWARGS = {
"temperature": 0,
}
Expand Down
1 change: 1 addition & 0 deletions wren-ai-service/src/providers/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def _verify_api_key(api_key: str, api_base: str) -> None:
"""
OpenAI(api_key=api_key, base_url=api_base).models.list()

logger.info(f"Initializing OpenAILLM provider with API base: {api_base}")
# TODO: currently only OpenAI api key can be verified
if api_base == OPENAI_API_BASE:
_verify_api_key(api_key.resolve_value(), api_base)
Expand Down
9 changes: 5 additions & 4 deletions wren-ai-service/src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@ def init_providers() -> Tuple[LLMProvider, DocumentStoreProvider, Engine]:
logger.info("Initializing providers...")
loader.import_mods()

llm_provider = loader.get_provider(os.getenv("LLM_PROVIDER", "openai"))
llm_provider = loader.get_provider(os.getenv("LLM_PROVIDER", "openai"))()
document_store_provider = loader.get_provider(
os.getenv("DOCUMENT_STORE_PROVIDER", "qdrant")
)
engine = loader.get_provider(os.getenv("ENGINE", "wren-ui"))
return llm_provider(), document_store_provider(), engine()
)()
engine = loader.get_provider(os.getenv("ENGINE", "wren-ui"))()

return llm_provider, document_store_provider, engine


def timer(func):
Expand Down
2 changes: 1 addition & 1 deletion wren-ai-service/tests/pytest/pipelines/test_ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def test_indexing_pipeline(
assert document_store_provider.get_store().count_documents() == 3
assert (
document_store_provider.get_store(
dataset_name="view_questions"
dataset_name="view_questions",
).count_documents()
== 1
)
Expand Down
22 changes: 22 additions & 0 deletions wren-launcher/commands/launch.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,23 @@ func isEnvFileValidForCustomLLM(projectDir string) error {
return nil
}

// renameEnvFileForCustomLLM renames an existing .env.ai file under projectDir
// to .env.ai.old, so the deprecated custom-LLM env file is preserved for the
// user but no longer picked up by docker compose. It is a no-op (returns nil)
// when the file is absent; it returns the os.Rename error on failure.
func renameEnvFileForCustomLLM(projectDir string) error {
	envFilePath := path.Join(projectDir, ".env.ai")

	// Guard clause: if stat fails (typically "file does not exist") there is
	// nothing to rename — mirror the original's best-effort behavior and
	// report success.
	if _, err := os.Stat(envFilePath); err != nil {
		return nil
	}

	// os.Rename returns nil on success, so its result can be returned directly.
	return os.Rename(envFilePath, path.Join(projectDir, ".env.ai.old"))
}


func Launch() {
// recover from panic
Expand All @@ -151,6 +168,11 @@ func Launch() {
openaiApiKey := ""
openaiGenerationModel := ""
if llmProvider == "OpenAI" {
err = renameEnvFileForCustomLLM(projectDir)
if err != nil {
panic(err)
}

// ask for OpenAI API key
pterm.Print("\n")
openaiApiKey, _ = askForAPIKey()
Expand Down
11 changes: 8 additions & 3 deletions wren-launcher/utils/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,14 @@ const (
PG_USERNAME string = "wren-user"
)

func replaceEnvFileContent(content string, openaiApiKey string, openAIGenerationModel string, hostPort int, aiPort int, pg_password string, userUUID string, telemetryEnabled bool) string {
func replaceEnvFileContent(content string, projectDir string, openaiApiKey string, openAIGenerationModel string, hostPort int, aiPort int, pg_password string, userUUID string, telemetryEnabled bool) string {
	// replace PROJECT_DIR
reg := regexp.MustCompile(`PROJECT_DIR=(.*)`)
str := reg.ReplaceAllString(content, "PROJECT_DIR="+projectDir)

// replace OPENAI_API_KEY
reg := regexp.MustCompile(`OPENAI_API_KEY=sk-(.*)`)
str := reg.ReplaceAllString(content, "OPENAI_API_KEY="+openaiApiKey)
reg = regexp.MustCompile(`OPENAI_API_KEY=sk-(.*)`)
str = reg.ReplaceAllString(content, "OPENAI_API_KEY="+openaiApiKey)

// replace GENERATION_MODEL
reg = regexp.MustCompile(`GENERATION_MODEL=(.*)`)
Expand Down Expand Up @@ -189,6 +193,7 @@ func PrepareDockerFiles(openaiApiKey string, openaiGenerationModel string, hostP
// replace the content with regex
envFileContent := replaceEnvFileContent(
string(envExampleFileContent),
projectDir,
openaiApiKey,
openaiGenerationModel,
hostPort,
Expand Down

0 comments on commit 16b475c

Please sign in to comment.