Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pull ollama models if not exist automatically #512

Merged
merged 7 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion wren-ai-service/.env.dev.example
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ WREN_UI_ENDPOINT=http://localhost:3000
WREN_IBIS_ENDPOINT=http://localhost:8000
WREN_IBIS_SOURCE=bigquery
WREN_IBIS_MANIFEST= # this is a base64 encoded string of the MDL
WREN_IBIS_CONNECTION_INFO={"project_id": "", "dataset_id": "", "credentials":""}
WREN_IBIS_CONNECTION_INFO= # this is a base64 encoded string of the connection info

# evaluation related
DATASET_NAME=book_2
Expand Down
366 changes: 190 additions & 176 deletions wren-ai-service/poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions wren-ai-service/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ sf-hamilton = {version = "==1.63.0", extras = ["visualization"]}
aiohttp = "==3.9.5"
ollama-haystack = "==0.0.6"
langfuse = "==2.35.0"
ollama = "==0.2.1"

[tool.poetry.group.dev.dependencies]
pytest = "==8.2.0"
Expand Down
3 changes: 2 additions & 1 deletion wren-ai-service/src/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def health():
host=server_host,
port=server_port,
reload=should_reload,
reload_dirs=["src"] if should_reload else None,
reload_dirs=["src"],
reload_includes=[".env.dev"],
workers=1,
loop="uvloop",
http="httptools",
Expand Down
4 changes: 3 additions & 1 deletion wren-ai-service/src/providers/embedder/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tqdm import tqdm

from src.core.provider import EmbedderProvider
from src.providers.loader import provider
from src.providers.loader import provider, pull_ollama_model
from src.utils import remove_trailing_slash

logger = logging.getLogger("wren-ai-service")
Expand Down Expand Up @@ -167,6 +167,8 @@ def __init__(
self._url = remove_trailing_slash(url)
self._embedding_model = embedding_model

pull_ollama_model(self._url, self._embedding_model)

logger.info(f"Using Ollama Embedding Model: {self._embedding_model}")
logger.info(f"Using Ollama URL: {self._url}")

Expand Down
7 changes: 5 additions & 2 deletions wren-ai-service/src/providers/engine/wren.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import base64
import logging
import os
from typing import Any, Dict, Optional, Tuple
Expand Down Expand Up @@ -54,8 +55,10 @@ async def dry_run_sql(
"source": os.getenv("WREN_IBIS_SOURCE"),
"manifest": os.getenv("WREN_IBIS_MANIFEST"),
"connection_info": orjson.loads(
os.getenv("WREN_IBIS_CONNECTION_INFO", "{}")
),
base64.b64decode(os.getenv("WREN_IBIS_CONNECTION_INFO"))
)
if os.getenv("WREN_IBIS_CONNECTION_INFO")
else {},
paopa marked this conversation as resolved.
Show resolved Hide resolved
},
) -> Tuple[bool, Optional[Dict[str, Any]]]:
async with session.post(
Expand Down
4 changes: 3 additions & 1 deletion wren-ai-service/src/providers/llm/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from haystack_integrations.components.generators.ollama import OllamaGenerator

from src.core.provider import LLMProvider
from src.providers.loader import provider
from src.providers.loader import provider, pull_ollama_model
from src.utils import remove_trailing_slash

logger = logging.getLogger("wren-ai-service")
Expand Down Expand Up @@ -130,6 +130,8 @@ def __init__(
self._url = remove_trailing_slash(url)
self._generation_model = generation_model

pull_ollama_model(self._url, self._generation_model)

logger.info(f"Using Ollama LLM: {self._generation_model}")
logger.info(f"Using Ollama URL: {self._url}")

Expand Down
18 changes: 18 additions & 0 deletions wren-ai-service/src/providers/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import logging
import pkgutil

from ollama import Client

logger = logging.getLogger("wren-ai-service")


Expand Down Expand Up @@ -98,3 +100,19 @@ def get_default_embedding_model_dim(embedder_provider: str):
return importlib.import_module(
f"src.providers.embedder.{file_name}"
).EMBEDDING_MODEL_DIMENSION


def pull_ollama_model(url: str, model_name: str):
    """Ensure *model_name* is available on the Ollama server, pulling it if absent.

    Args:
        url: Base URL of the Ollama server (e.g. ``http://localhost:11434``).
        model_name: Model to check/pull, e.g. ``llama3:latest``. Note that
            Ollama reports names with an explicit tag; pass the tagged form
            to avoid spurious re-pulls (``llama3`` != ``llama3:latest``).
    """
    client = Client(host=url)
    # client.list()["models"] is a list of model dicts, not bare names.
    # The original `model_name not in models` membership test therefore
    # never matched and forced a re-pull on every startup — compare
    # against each entry's "name" field instead.
    models = [model["name"] for model in client.list()["models"]]
    if model_name not in models:
        logger.info(f"Pulling Ollama model {model_name}")
        percentage = 0
        for progress in client.pull(model_name, stream=True):
            # "completed"/"total" are only present on download-phase events;
            # status-only events (e.g. "verifying sha256") omit them. Also
            # guard total == 0 to avoid ZeroDivisionError.
            if "completed" in progress and "total" in progress and progress["total"]:
                new_percentage = int(progress["completed"] / progress["total"] * 100)
                if new_percentage > percentage:
                    percentage = new_percentage
                    logger.info(f"Pulling Ollama model {model_name}: {percentage}%")
    else:
        logger.info(f"Ollama model {model_name} already exists")
9 changes: 7 additions & 2 deletions wren-launcher/commands/launch.go
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,12 @@ func Launch() {
panic(err)
}

// wait for 10 seconds
pterm.Info.Println("Wren AI is starting, please wait for a moment...")
if llmProvider == "Custom" {
pterm.Info.Println("If you choose Ollama as LLM provider, please make sure you have started the Ollama service first. Also, Wren AI will automatically pull your chosen models if you have not done so. You can check the progress by executing `docker logs -f wrenai-wren-ai-service-1` in the terminal.")
}
url := fmt.Sprintf("http://localhost:%d", uiPort)
// wait until checking if CheckWrenAIStarted return without error
// wait until CheckUIServiceStarted returns without error
// if timeout 2 minutes, panic
timeoutTime := time.Now().Add(2 * time.Minute)
for {
Expand All @@ -230,6 +232,9 @@ func Launch() {
time.Sleep(5 * time.Second)
}

// wait until checking if CheckWrenAIStarted return without error
// if timeout 30 minutes, panic
timeoutTime = time.Now().Add(30 * time.Minute)
for {
if time.Now().After(timeoutTime) {
panic("Timeout")
Expand Down
Loading