Skip to content

Commit

Permalink
resolve circular dep, correct vdb url, resp_format fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Dnouv committed May 30, 2024
1 parent 6ff667b commit 6fdca85
Show file tree
Hide file tree
Showing 8 changed files with 36 additions and 15 deletions.
2 changes: 1 addition & 1 deletion core/local_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def simple_qa(query: str, context: str) -> str:
temperature=0.1,
messages=messages,
stream=False,
response_format="web",
response_format={"type": "text"}, # mlc doesn't supports string "web"
)
return response.choices[0].message.content

Expand Down
10 changes: 10 additions & 0 deletions core/tasks/celery_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import redis
from celery import Celery
from typing import cast
import core.config as configs


redis_client = cast(redis.Redis, redis.Redis.from_url(configs.redis_url)) # annoyingly from_url returns None, not Self
app = Celery("tasks", broker=configs.redis_url)

app.autodiscover_tasks(["core.tasks"]) # Explicitly discover tasks in 'app' package
4 changes: 2 additions & 2 deletions core/tasks/is_ready.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from .tasks import app

import socket

import requests

from core.config import litellm_url, vector_db_url

from .celery_app import app

def is_ready():
# response = requests.get(f"{litellm_url}/health", headers={
# "Authorization": f"Bearer {os.getenv('LITELLM_MASTER_KEY', '')}"
Expand Down
16 changes: 8 additions & 8 deletions core/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,15 @@
import core.config as configs

from .is_ready import is_ready

redis_client = cast(redis.Redis, redis.Redis.from_url(configs.redis_url)) # annoyingly from_url returns None, not Self
app = Celery("tasks", broker=configs.redis_url)

app.autodiscover_tasks(["core.tasks"]) # Explicitly discover tasks in 'app' package
from .celery_app import app

# Global MongoDB client
mongo_client: MongoClient
mongo_client: MongoClient = None

redis_client = cast(redis.Redis, redis.Redis.from_url(configs.redis_url)) # annoyingly from_url returns None, not Self

@signals.worker_process_init.connect
async def ensure_connections(*args, **kwargs):
def ensure_connections(*args, **kwargs):
global mongo_client
mongo_client = MongoClient(configs.mongo_url)

Expand Down Expand Up @@ -220,11 +217,12 @@ def form_openai_tools(tools, assistant_id: str):
@shared_task
def execute_chat_completion(assistant_id, thread_id, redis_channel, run_id):
try:
db = mongo_client[configs.mongo_database] # OpenAI call can fail, so we need to get the db again

oai_client = OpenAI(
base_url=configs.litellm_url,
api_key=os.getenv("LITELLM_MASTER_KEY"), # point to litellm server
)
db = mongo_client[configs.mongo_database]

# Fetch assistant and thread messages synchronously
assistant = db.assistants.find_one({"id": assistant_id})
Expand Down Expand Up @@ -459,6 +457,8 @@ def execute_chat_completion(assistant_id, thread_id, redis_channel, run_id):
@app.task
def execute_asst_file_create(file_id: str, assistant_id: str):
try:
if mongo_client is None:
raise Exception("MongoDB client not initialized yet")
db = mongo_client[configs.mongo_database]
collection_name = assistant_id
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
Expand Down
2 changes: 1 addition & 1 deletion core/tools/knowledge/file_knowledge_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# Third Party
import requests

vector_db_url = f"{configs.vector_db_url}/similarity_search"
vector_db_url = f"{configs.vector_db_url}/similarity_match"

class FileKnowledgeTool:
name = "FileKnowledge"
Expand Down
1 change: 1 addition & 0 deletions core/tools/knowledge/vector_db/milvus/query_milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def __init__(
if alias is not None:
self.alias = alias
elif connection_args is not None:
connection_args = DEFAULT_MILVUS_CONNECTION
self.alias = Milvus.create_connection_alias(connection_args)
else:
raise ValueError('alias or connection_args must be passed to Milvus construtor')
Expand Down
2 changes: 2 additions & 0 deletions docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ services:
- EMBEDDING_HOST=text-embedding-api
- VECTOR_DB_HOST=vector-db-api
- MILVUS_HOST=milvus
- LITELLM_MASTER_KEY=abc
depends_on:
- redis
- mongodb
Expand All @@ -134,6 +135,7 @@ services:
- REDIS_HOST=redis
- MONGODB_HOST=mongodb
- LITELLM_HOST=litellm
- MILVUS_HOST=milvus
ports:
- '8000:8000'
depends_on:
Expand Down
14 changes: 11 additions & 3 deletions services/backend/api_server/app/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,16 @@

# MongoDB Configurationget
LITELLM_URL = configs.litellm_url
LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "")
LITELLM_MASTER_KEY = os.getenv("LITELLM_MASTER_KEY", "abc") # Litellm fails without this key
HEADERS = {"accept": "application/json", "Content-Type": "application/json"}

# Initialize MongoDB client
mongo_client = AsyncIOMotorClient(configs.mongo_url, server_api=ServerApi("1"))
database = mongo_client[configs.mongo_database]

celery_app = Celery(configs.redis_url)
celery_app = Celery(broker=configs.redis_url)

redis = aioredis.from_url(configs.redis_url)
redis = aioredis.from_url(configs.redis_url, encoding="utf-8", decode_responses=True)

logging.basicConfig(level=logging.INFO)

Expand Down Expand Up @@ -764,6 +764,14 @@ async def redis_subscriber(channel, timeout=1):
pubsub = redis.pubsub()
await pubsub.subscribe(channel)

# Check if the subscription was successful
channels = await redis.pubsub_channels()
logging.info(f"Channels: {channels}")
if channel in channels:
logging.info(f"Successfully subscribed to channel: {channel}")
else:
logging.error(f"Failed to subscribe to channel: {channel}")

while True:
try:
message = await asyncio.wait_for(
Expand Down

0 comments on commit 6fdca85

Please sign in to comment.