Skip to content

Commit

Permalink
Add Replicate as a provider
Browse files Browse the repository at this point in the history
With this patch, the model can be executed in a cloud (Replicate).
  • Loading branch information
mareklibra committed Oct 22, 2024
1 parent e9b9355 commit c209aee
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 80 deletions.
13 changes: 11 additions & 2 deletions Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,22 @@ python main.py run
This will create a web server on port 5000, and the user can use the browser to
iterate over it.

## run with Replicate
Without a GPU installed locally, you may need to run the models in a cloud service such as [Replicate](https://replicate.com/).

```
export REPLICATE_API_TOKEN=r8_2.........
LLM_RUNNER=replicate LOG_LEVEL=DEBUG LLM_MODEL="granite-code:8b" python main.py run
```

# Environment Variables

| Environment Variable | Default Value | Description |
|-------------------------|--------------------------------------------------------------|-----------------------------------------------|
| `OLLAMA_MODEL` | `granite-code:8b` | Specifies the model used by Ollama. |
| `LLM_MODEL`             | `granite-code:8b`                                              | Specifies the model used by the LLM runner.    |
| `LLM_URL` | `http://localhost:11434` | Base URL for Ollama API. |
| `LLM_RUNNER` | `ollama` | Runner (Provider) for running the models. Either "ollama" or "replicate" |
| `LOG_LEVEL`             | `INFO`                                                         | Logging verbosity level (e.g. `DEBUG`, `INFO`). |
| `OLLAMA_URL` | `http://localhost:11434` | Base URL for Ollama API. |
| `FAISS_DB` | `/tmp/db_faiss` | Path or reference to the local FAISS database.|
| `WORKFLOW_SCHEMA_URL` | `https://raw.githubusercontent.com/serverlessworkflow/specification/main/schema/workflow.yaml` | URL for the serverless workflow JSON schema. |
| `SQLITE` | `chats.db` | path to store previous chats |
Expand Down
6 changes: 3 additions & 3 deletions api/chat.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import uuid
import logging

from flask import jsonify, g, Response, request
from services.chats import get_response_for_session, get_history
from services.chats import get_workflow_for_session, get_all_workflow_for_session

from services.chats import get_response_for_session, get_history, get_all_workflow_for_session, get_workflow_for_session

def list_chats():
sessions = g.ctx.history_repo.get_all_sessions()
Expand Down Expand Up @@ -36,6 +35,7 @@ def get_all_workflow(session_id):


def push_new_message(session_id):
logging.debug("push_new_message start")
user_input = request.json.get("input", "")
if len(user_input) == 0:
return Response("Invalid user input", status=400)
Expand Down
7 changes: 4 additions & 3 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
OLLAMA_URL = "http://localhost:11434"
SPEC_URL = "https://raw.githubusercontent.com/serverlessworkflow/specification/0.8.x/specification.md" # noqa E501
SQLITE_DB = "chats.db"

LLM_RUNNER = "ollama" # or "replicate"

class Config:
def __init__(self):
self.model = self.get_env_variable('OLLAMA_MODEL', MODEL)
self.base_url = self.get_env_variable('OLLAMA_URL', OLLAMA_URL)
self.llm_runner = self.get_env_variable('LLM_RUNNER', LLM_RUNNER)
self.model = self.get_env_variable('LLM_MODEL', MODEL)
self.base_url = self.get_env_variable('LLM_URL', OLLAMA_URL)
self.db = self.get_env_variable('FAISS_DB', FAISS_DB)
self.chat_db = self.get_env_variable('SQLITE_DB', SQLITE_DB)

Expand Down
63 changes: 0 additions & 63 deletions lib/ollama.py

This file was deleted.

3 changes: 0 additions & 3 deletions lib/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ def __init__(self, path, embeddings, index_name="faissIndex", embeddings_len=409
def _initialize(self):
# Same initialization as here:
# https://github.com/langchain-ai/langchain/blob/379803751e5ae40a2aadcb4072dbb2525187dd1f/libs/community/langchain_community/vectorstores/faiss.py#L871 # noqa E501
# @TODO The 4096 is the shape size of the embeddings, currently
# hardcoded, but maybe we can get from OllamaEmbeddings class?
#index = IndexFlatL2(len(self.embeddings.embed_query("hello world")))

index = IndexFlatL2(self.embeddings_len)
self.faiss = FAISS(
Expand Down
9 changes: 5 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from lib.history import HistoryRepository
from lib.json_validator import JsonSchemaValidatorTool, JsonSchemaValidationException
from lib.models import SerVerlessWorkflow
from lib.ollama import Ollama
from lib.llm_runner import LlmRunner
from lib.repository import VectorRepository
from lib.retriever import Retriever
from lib.validator import OutputValidator
Expand All @@ -28,9 +28,10 @@
class Context:
def __init__(self, config):
self.config = config
self.ollama = Ollama(self.config.base_url, self.config.model)

self.repo = VectorRepository(self.config.db, self.ollama.embeddings, embeddings_len=MODELS_EMBEDDINGS.get(self.config.model, 4096))
self.llm_runner = LlmRunner(self.config.llm_runner, self.config.base_url, self.config.model)
self.repo = VectorRepository(self.config.db, self.llm_runner.embeddings, embeddings_len=MODELS_EMBEDDINGS.get(self.config.model, 4096))

self.validator = OutputValidator(
SerVerlessWorkflow,
JsonSchemaValidatorTool.load_from_file("lib/schema/workflow.json"))
Expand Down Expand Up @@ -72,7 +73,7 @@ def load_data(obj, file_path):
sys.exit(1)

splitter = Retriever.get_splitters(file_path)
documents = obj.ollama.parse_document(splitter, content)
documents = obj.llm_runner.parse_document(splitter, content)

if len(documents) == 0:
click.echo("The len of the documents is 0")
Expand Down
1 change: 0 additions & 1 deletion req.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ cffi==1.16.0
chardet==5.2.0
charset-normalizer==3.3.2
click==8.1.7
cryptography==42.0.5
dataclasses-json==0.6.7
decorator==5.1.1
deepdiff==8.0.1
Expand Down
7 changes: 6 additions & 1 deletion services/chats.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def __init__(self, llm, retriever, history, session_id):
}
| prompt
| llm)
logging.debug("ChatChain building is finished")

def react(self, user_message):
prompt = self.get_react_prompt()
Expand Down Expand Up @@ -287,19 +288,23 @@ def stream(self, data):


def get_response_for_session(ctx, session_id, user_message):
logging.debug("get_response_for_session started")
retriever = ctx.repo.retriever
llm = ctx.ollama.llm
llm = ctx.llm_runner.llm
# @TODO check if this clone the thing to a new object.
history_repo = ctx.history_repo
history_repo.set_session(f"{session_id}")
yield "Waiting on AI response\n"
chain = ChatChain(llm, retriever, history_repo, session_id)
ai_response = []
result = chain.stream({"input": user_message})
logging.debug("Result received")
for x in result:
logging.debug("Passing through result")
ai_response.append(x.content)
yield str(x.content)

logging.debug("Validating json")
yield "\nChecking if json is correct and validation\n\n"
full_ai_response = "".join(ai_response)
validator = ValidatingJsonWorkflow(chain, session_id, full_ai_response, ctx.validator)
Expand Down

0 comments on commit c209aee

Please sign in to comment.