From c209aee612ec32a5e47b2ea1db0add4ae5636601 Mon Sep 17 00:00:00 2001
From: Marek Libra
Date: Tue, 22 Oct 2024 15:04:25 +0200
Subject: [PATCH] Add Replicate as a provider

With this patch, the model can be run in the cloud (Replicate).
---
 Readme.md         | 13 ++++++++--
 api/chat.py       |  6 ++---
 config.py         |  7 +++---
 lib/ollama.py     | 63 -----------------------------------------------
 lib/repository.py |  3 ---
 main.py           |  9 ++++---
 req.txt           |  1 -
 services/chats.py |  7 +++++-
 8 files changed, 29 insertions(+), 80 deletions(-)
 delete mode 100644 lib/ollama.py

diff --git a/Readme.md b/Readme.md
index 71dfb84..3b9cce6 100644
--- a/Readme.md
+++ b/Readme.md
@@ -86,13 +86,22 @@ python main.py run
 
 This will create a web server on port 5000, and user can use the browser to iterate over it.
 
+## Run with Replicate
+Without a local GPU, the models can be run in the cloud, for example on [Replicate](https://replicate.com/).
+
+```
+export REPLICATE_API_TOKEN=r8_2.........
+LLM_RUNNER=replicate LOG_LEVEL=DEBUG LLM_MODEL="granite-code:8b" python main.py run
+```
+
 # Environment Variables
 
 | Environment Variable    | Default Value            | Description                                    |
 |-------------------------|--------------------------|------------------------------------------------|
-| `OLLAMA_MODEL`          | `granite-code:8b`        | Specifies the model used by Ollama.            |
+| `LLM_MODEL`             | `granite-code:8b`        | Specifies the model used by the LLM runner.    |
+| `LLM_URL`               | `http://localhost:11434` | Base URL for the Ollama API.                   |
+| `LLM_RUNNER`            | `ollama`                 | Provider used to run the models; either `ollama` or `replicate`. |
 | `LOG_LEVEL`             | `INFO`                   | LOG LEVEL information.                         |
-| `OLLAMA_URL`            | `http://localhost:11434` | Base URL for Ollama API.                       |
 | `FAISS_DB`              | `/tmp/db_faiss`          | Path or reference to the local FAISS database. |
 | `WORKFLOW_SCHEMA_URL`   | `https://raw.githubusercontent.com/serverlessworkflow/specification/main/schema/workflow.yaml` | URL for the serverless workflow JSON schema.   |
 | `SQLITE`                | `chats.db`               | path to store previous chats                   |
diff --git a/api/chat.py b/api/chat.py
index 3054036..8103d6e 100644
--- a/api/chat.py
+++ b/api/chat.py
@@ -1,9 +1,8 @@
 import uuid
+import logging
 
 from flask import jsonify, g, Response, request
-from services.chats import get_response_for_session, get_history
-from services.chats import get_workflow_for_session, get_all_workflow_for_session
-
+from services.chats import get_response_for_session, get_history, get_all_workflow_for_session, get_workflow_for_session
 
 def list_chats():
     sessions = g.ctx.history_repo.get_all_sessions()
@@ -36,6 +35,7 @@ def get_all_workflow(session_id):
 
 
 def push_new_message(session_id):
+    logging.debug("push_new_message start")
     user_input = request.json.get("input", "")
     if len(user_input) == 0:
         return Response("Invalid user input", status=400)
diff --git a/config.py b/config.py
index 27fa77e..7f53704 100644
--- a/config.py
+++ b/config.py
@@ -8,12 +8,13 @@
 OLLAMA_URL = "http://localhost:11434"
 SPEC_URL = "https://raw.githubusercontent.com/serverlessworkflow/specification/0.8.x/specification.md"  # noqa E501
 SQLITE_DB = "chats.db"
-
+LLM_RUNNER = "ollama"  # or "replicate"
 
 class Config:
     def __init__(self):
-        self.model = self.get_env_variable('OLLAMA_MODEL', MODEL)
-        self.base_url = self.get_env_variable('OLLAMA_URL', OLLAMA_URL)
+        self.llm_runner = self.get_env_variable('LLM_RUNNER', LLM_RUNNER)
+        self.model = self.get_env_variable('LLM_MODEL', MODEL)
+        self.base_url = self.get_env_variable('LLM_URL', OLLAMA_URL)
 
         self.db = self.get_env_variable('FAISS_DB', FAISS_DB)
         self.chat_db = self.get_env_variable('SQLITE_DB', SQLITE_DB)
diff --git a/lib/ollama.py b/lib/ollama.py
deleted file mode 100644
index 243a19b..0000000
--- a/lib/ollama.py
+++ /dev/null
@@ -1,63 +0,0 @@
-from langchain_community.chat_models import ChatOllama
-# from langchain_experimental.llms.ollama_functions import OllamaFunctions
-# from langchain_openai import OpenAIEmbeddings
-
-from langchain_community.embeddings import OllamaEmbeddings
-# from langchain_openai import ChatOpenAI
-
-
-class Ollama():
-    def __init__(self, base_url, model):
-        self.base_url = base_url
-        self.model = model
-        self._embeddings = None
-        self._ollama = None
-
-    @property
-    def llm(self):
-        # Normally we should use ChatOllama, but does not support
-        # OllamaFunctions.
-        # On the other hand, the OpenAI has better support for multiple actions
-        # in langchain, let's take advantage of it
-        if not self._ollama:
-            self._ollama = ChatOllama(
-                base_url=self.base_url,
-                model=self.model,
-                temperature=0,
-                num_ctx=20000,
-            )
-            # self._ollama = ChatOpenAI(
-            #     base_url=self.base_url + "/v1",
-            #     api_key="no-need",
-            #     model=self.model,
-            #     temperature=0
-            # )
-            # self._ollama = OllamaFunctions(
-            #     base_url=self.base_url,
-            #     model=self.model,
-            #     format="json")
-        return self._ollama
-
-    @property
-    def embeddings(self):
-        if not self._embeddings:
-            self._embeddings = OllamaEmbeddings(
-                base_url=self.base_url,
-                model=self.model)
-            # self._embeddings = OpenAIEmbeddings(
-            #     base_url=self.base_url + "/api",
-            #     api_key="no-need",
-            #     model=self.model
-            # )
-        return self._embeddings
-
-    @classmethod
-    def parse_document(cls, splitter, content):
-        # @TODO here the splitter can be different, and we should check the
-        # content[].metadata.source filename type, so:
-        # - Markdown: https://api.python.langchain.com/en/latest/text_splitters_api_reference.html#module-langchain_text_splitters.markdown  # noqa: E501
-        # - HTML: https://api.python.langchain.com/en/latest/text_splitters_api_reference.html#module-langchain_text_splitters.html  # noqa: E501
-
-        text_splitter = splitter()
-        documents = text_splitter.split_documents(content)
-        return documents
diff --git a/lib/repository.py b/lib/repository.py
index 36e6f92..10bc771 100644
--- a/lib/repository.py
+++ b/lib/repository.py
@@ -19,9 +19,6 @@ def __init__(self, path, embeddings, index_name="faissIndex", embeddings_len=409
     def _initialize(self):
         # Same initialization as here:
         # https://github.com/langchain-ai/langchain/blob/379803751e5ae40a2aadcb4072dbb2525187dd1f/libs/community/langchain_community/vectorstores/faiss.py#L871  # noqa E501
-        # @TODO The 4096 is the shape size of the embeddings, currently
-        # hardcoded, but maybe we can get from OllamaEmbeddings class?
-        #index = IndexFlatL2(len(self.embeddings.embed_query("hello world")))
         index = IndexFlatL2(self.embeddings_len)
 
         self.faiss = FAISS(
diff --git a/main.py b/main.py
index 1297b20..a5f85f0 100644
--- a/main.py
+++ b/main.py
@@ -10,7 +10,7 @@
 from lib.history import HistoryRepository
 from lib.json_validator import JsonSchemaValidatorTool, JsonSchemaValidationException
 from lib.models import SerVerlessWorkflow
-from lib.ollama import Ollama
+from lib.llm_runner import LlmRunner
 from lib.repository import VectorRepository
 from lib.retriever import Retriever
 from lib.validator import OutputValidator
@@ -28,9 +28,10 @@ class Context:
 
     def __init__(self, config):
         self.config = config
-        self.ollama = Ollama(self.config.base_url, self.config.model)
-        self.repo = VectorRepository(self.config.db, self.ollama.embeddings, embeddings_len=MODELS_EMBEDDINGS.get(self.config.model, 4096))
+        self.llm_runner = LlmRunner(self.config.llm_runner, self.config.base_url, self.config.model)
+        self.repo = VectorRepository(self.config.db, self.llm_runner.embeddings, embeddings_len=MODELS_EMBEDDINGS.get(self.config.model, 4096))
+
         self.validator = OutputValidator(
             SerVerlessWorkflow,
             JsonSchemaValidatorTool.load_from_file("lib/schema/workflow.json"))
@@ -72,7 +73,7 @@ def load_data(obj, file_path):
         sys.exit(1)
 
     splitter = Retriever.get_splitters(file_path)
-    documents = obj.ollama.parse_document(splitter, content)
+    documents = obj.llm_runner.parse_document(splitter, content)
 
     if len(documents) == 0:
         click.echo("The len of the documents is 0")
diff --git a/req.txt b/req.txt
index eecd22b..719b755 100644
--- a/req.txt
+++ b/req.txt
@@ -18,7 +18,6 @@ cffi==1.16.0
 chardet==5.2.0
 charset-normalizer==3.3.2
 click==8.1.7
-cryptography==42.0.5
 dataclasses-json==0.6.7
 decorator==5.1.1
 deepdiff==8.0.1
diff --git a/services/chats.py b/services/chats.py
index 07e2ff0..1e08fb0 100644
--- a/services/chats.py
+++ b/services/chats.py
@@ -248,6 +248,7 @@ def __init__(self, llm, retriever, history, session_id):
             } | prompt
             | llm)
+        logging.debug("ChatChain building is finished")
 
 
     def react(self, user_message):
         prompt = self.get_react_prompt()
@@ -287,8 +288,9 @@ def stream(self, data):
 
 
 def get_response_for_session(ctx, session_id, user_message):
+    logging.debug("get_response_for_session started")
     retriever = ctx.repo.retriever
-    llm = ctx.ollama.llm
+    llm = ctx.llm_runner.llm
     # @TODO check if this clone the thing to a new object.
     history_repo = ctx.history_repo
     history_repo.set_session(f"{session_id}")
@@ -296,10 +298,13 @@
     chain = ChatChain(llm, retriever, history_repo, session_id)
     ai_response = []
     result = chain.stream({"input": user_message})
+    logging.debug("Result received")
     for x in result:
+        logging.debug("Passing through result")
         ai_response.append(x.content)
         yield str(x.content)
 
+    logging.debug("Validating json")
     yield "\nChecking if json is correct and validation\n\n"
     full_ai_response = "".join(ai_response)
     validator = ValidatingJsonWorkflow(chain, session_id, full_ai_response, ctx.validator)
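Note: `main.py` now imports `LlmRunner` from `lib/llm_runner.py`, but that new module is not included in the diff above (the diffstat lists only the eight files shown). For reference, here is a minimal sketch of what such a runner could look like. It is illustrative, not the patch author's code: the `ollama` branch reuses the classes from the deleted `lib/ollama.py`, while the `replicate` branch assumes `langchain_community`'s `Replicate` wrapper and a model id that Replicate recognizes.

```python
# lib/llm_runner.py -- hypothetical sketch, not part of this patch.
from langchain_community.chat_models import ChatOllama
from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.llms import Replicate


class LlmRunner():
    def __init__(self, runner, base_url, model):
        self.runner = runner        # "ollama" or "replicate"
        self.base_url = base_url    # only used by the ollama runner
        self.model = model
        self._llm = None
        self._embeddings = None

    @property
    def llm(self):
        if not self._llm:
            if self.runner == "replicate":
                # Replicate reads REPLICATE_API_TOKEN from the environment.
                # Assumes self.model is a valid Replicate model id.
                self._llm = Replicate(model=self.model,
                                      model_kwargs={"temperature": 0})
            else:
                # Same configuration as the deleted lib/ollama.py.
                self._llm = ChatOllama(
                    base_url=self.base_url,
                    model=self.model,
                    temperature=0,
                    num_ctx=20000,
                )
        return self._llm

    @property
    def embeddings(self):
        # Assumption: embeddings stay on Ollama for both runners, since the
        # patch keeps the FAISS/MODELS_EMBEDDINGS handling in main.py unchanged.
        if not self._embeddings:
            self._embeddings = OllamaEmbeddings(base_url=self.base_url,
                                                model=self.model)
        return self._embeddings

    @classmethod
    def parse_document(cls, splitter, content):
        # Same behaviour as the deleted Ollama.parse_document.
        text_splitter = splitter()
        return text_splitter.split_documents(content)
```

Keeping `llm`, `embeddings` and `parse_document` on one object preserves the call sites touched in `main.py` and `services/chats.py`, so switching providers remains a matter of setting `LLM_RUNNER`.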