From 9c6202edbea4abaaf4c4ca4ab4ee6a426c045c4f Mon Sep 17 00:00:00 2001 From: Matt Bertrand Date: Thu, 19 Dec 2024 09:14:38 -0500 Subject: [PATCH] Alternate SSE endpoint --- ai_agents/agents.py | 10 ++- ai_agents/consumers.py | 125 ++++++++++++++++++++++++++++++------ ai_agents/consumers_test.py | 2 +- ai_agents/routing.py | 14 +++- main/asgi.py | 8 ++- main/settings.py | 9 +++ poetry.lock | 23 ++++++- pyproject.toml | 1 + 8 files changed, 158 insertions(+), 34 deletions(-) diff --git a/ai_agents/agents.py b/ai_agents/agents.py index 70a8858..67b4103 100644 --- a/ai_agents/agents.py +++ b/ai_agents/agents.py @@ -10,9 +10,9 @@ from django.conf import settings from django.core.cache import caches from django.utils.module_loading import import_string -from llama_cloud import ChatMessage from llama_index.agent.openai import OpenAIAgent from llama_index.core.agent import AgentRunner +from llama_index.core.base.llms.types import ChatMessage from llama_index.core.constants import DEFAULT_TEMPERATURE from llama_index.core.tools import FunctionTool, ToolMetadata from llama_index.llms.openai import OpenAI @@ -335,8 +335,6 @@ def __init__( ) self.search_parameters = [] self.search_results = [] - - self.agent = self.create_agent() self.create_agent() def search_courses(self, q: str, **kwargs) -> str: @@ -411,15 +409,15 @@ def create_openai_agent(self) -> OpenAIAgent: self.proxy.get_additional_kwargs(self) if self.proxy else {} ), ) - agent = OpenAIAgent.from_tools( + self.agent = OpenAIAgent.from_tools( tools=self.create_tools(), llm=llm, verbose=True, system_prompt=self.instructions, ) - if settings.AI_CACHE_HISTORY: + if self.save_history: self.get_or_create_chat_history_cache() - return agent + return self.agent def create_tools(self): """Create tools required by the agent""" diff --git a/ai_agents/consumers.py b/ai_agents/consumers.py index ae6f070..2157015 100644 --- a/ai_agents/consumers.py +++ b/ai_agents/consumers.py @@ -1,7 +1,10 @@ import json import logging +from 
channels.generic.http import AsyncHttpConsumer from channels.generic.websocket import AsyncWebsocketConsumer +from channels.layers import get_channel_layer +from django.utils.text import slugify from llama_index.core.base.llms.types import ChatMessage from ai_agents.agents import RecommendationAgent @@ -10,7 +13,33 @@ log = logging.getLogger(__name__) -class RecommendationAgentConsumer(AsyncWebsocketConsumer): +def process_message(message_json, agent) -> str: + """ + Validate the message, update the agent if necessary + """ + text_data_json = json.loads(message_json) + serializer = ChatRequestSerializer(data=text_data_json) + serializer.is_valid(raise_exception=True) + message_text = serializer.validated_data.pop("message", "") + clear_history = serializer.validated_data.pop("clear_history", False) + temperature = serializer.validated_data.pop("temperature", None) + instructions = serializer.validated_data.pop("instructions", None) + model = serializer.validated_data.pop("model", None) + + if clear_history: + agent.agent.clear_chat_history() + if model: + agent.agent.agent_worker._llm.model = model # noqa: SLF001 + if temperature: + agent.agent.agent_worker._llm.temperature = temperature # noqa: SLF001 + if instructions: + agent.agent.agent_worker.prefix_messages = [ + ChatMessage(content=instructions, role="system") + ] + return message_text + + +class RecommendationAgentWSConsumer(AsyncWebsocketConsumer): """ Async websocket consumer for the recommendation agent. 
""" @@ -36,25 +65,7 @@ async def receive(self, text_data: str) -> str: """Send the message to the AI agent and return its response.""" try: - text_data_json = json.loads(text_data) - serializer = ChatRequestSerializer(data=text_data_json) - serializer.is_valid(raise_exception=True) - message_text = serializer.validated_data.pop("message", "") - clear_history = serializer.validated_data.pop("clear_history", False) - temperature = serializer.validated_data.pop("temperature", None) - instructions = serializer.validated_data.pop("instructions", None) - model = serializer.validated_data.pop("model", None) - - if clear_history: - self.agent.clear_chat_history() - if model: - self.agent.agent.agent_worker._llm.model = model # noqa: SLF001 - if temperature: - self.agent.agent.agent_worker._llm.temperature = temperature # noqa: SLF001 - if instructions: - self.agent.agent.agent_worker.prefix_messages = [ - ChatMessage(content=instructions, role="system") - ] + message_text = process_message(text_data, self.agent) for chunk in self.agent.get_completion(message_text): await self.send(text_data=chunk) @@ -63,3 +74,77 @@ async def receive(self, text_data: str) -> str: finally: # This is a bit hacky, but it works for now await self.send(text_data="!endResponse") + + +class RecommendationAgentSSEConsumer(AsyncHttpConsumer): + async def handle(self, message: str): + user = self.scope.get("user", None) + session = self.scope.get("session", None) + + if user and user.username and user.username != "AnonymousUser": + self.user_id = user.username + elif session: + if not session.session_key: + session.save() + self.user_id = slugify(session.session_key)[:100] + else: + log.info("Anon user, no session") + self.user_id = "Anonymous" + + agent = RecommendationAgent(self.user_id) + + self.channel_layer = get_channel_layer() + self.room_name = "recommendation_bot" + self.room_group_name = f"recommendation_bot_{self.user_id}" + await self.channel_layer.group_add( + 
f"recommendation_bot_{self.user_id}", self.channel_name + ) + + await self.send_headers( + headers=[ + (b"Cache-Control", b"no-cache"), + ( + b"Content-Type", + b"text/event-stream", + ), + ( + b"Transfer-Encoding", + b"chunked", + ), + (b"Connection", b"keep-alive"), + ] + ) + # Headers are only sent after the first body event. + # Set "more_body" to tell the interface server to not + # finish the response yet: + payload = "\nevent: ping\ndata: null\n\n\n" + await self.send_body(payload.encode("utf-8"), more_body=True) + + try: + message_text = process_message(message, agent) + + for chunk in agent.get_completion(message_text): + await self.send_event(event=chunk) + except: # noqa: E722 + log.exception("Error in RecommendationAgentConsumer") + finally: + await self.disconnect() + + async def disconnect(self): + await self.channel_layer.group_discard(f"recommendation_bot_{self.user_id}", self.channel_name) + + async def send_event(self, event: str): + # Send response event + log.info(event) + data = f"event: agent_response\ndata: {event}\n\n" + await self.send_body(data.encode("utf-8"), more_body=True) + + async def http_request(self, message): + """ + Receives an SSE request and holds the connection open + until the client or server chooses to disconnect. 
+ """ + try: + await self.handle(message.get("body")) + finally: + pass diff --git a/ai_agents/consumers_test.py b/ai_agents/consumers_test.py index f08cc17..eee481e 100644 --- a/ai_agents/consumers_test.py +++ b/ai_agents/consumers_test.py @@ -41,7 +41,7 @@ def agent_user(): @pytest.fixture def recommendation_consumer(agent_user): """Return a recommendation consumer.""" - consumer = consumers.RecommendationAgentConsumer() + consumer = consumers.RecommendationAgentWSConsumer() consumer.scope = {"user": agent_user} return consumer diff --git a/ai_agents/routing.py b/ai_agents/routing.py index d946370..aa95e43 100644 --- a/ai_agents/routing.py +++ b/ai_agents/routing.py @@ -2,11 +2,19 @@ from ai_agents import consumers -websocket_urlpatterns = [ +websocket_patterns = [ # websocket URLs go here re_path( r"ws/recommendation_agent/", - consumers.RecommendationAgentConsumer.as_asgi(), - name="recommendation_agent", + consumers.RecommendationAgentWSConsumer.as_asgi(), + name="recommendation_agent_ws", + ), +] + +http_patterns = [ + re_path( + r"sse/recommendation_agent/", + consumers.RecommendationAgentSSEConsumer.as_asgi(), + name="recommendation_agent_sse", ), ] diff --git a/main/asgi.py b/main/asgi.py index 3a4c428..8b3b750 100644 --- a/main/asgi.py +++ b/main/asgi.py @@ -6,13 +6,15 @@ os.environ.setdefault("DJANGO_SETTINGS_MODULE", "main.settings") -import ai_agents.routing +django_asgi_app = get_asgi_application() + +import ai_agents.routing # noqa: E402 application = ProtocolTypeRouter( { - "http": get_asgi_application(), + "http": AuthMiddlewareStack(URLRouter(ai_agents.routing.http_patterns)), "websocket": AuthMiddlewareStack( - URLRouter(ai_agents.routing.websocket_urlpatterns) + URLRouter(ai_agents.routing.websocket_patterns) ), } ) diff --git a/main/settings.py b/main/settings.py index 77bea5d..86c05cb 100644 --- a/main/settings.py +++ b/main/settings.py @@ -548,6 +548,15 @@ def get_all_config_keys(): KEYCLOAK_ADMIN_SECURE = get_bool("KEYCLOAK_ADMIN_SECURE", 
True) # noqa: FBT003 +CHANNEL_LAYERS = { + "default": { + "BACKEND": "channels_redis.core.RedisChannelLayer", + "CONFIG": { + "hosts": [("redis", 6379)], + }, + }, +} + # AI settings AI_DEBUG = get_bool("AI_DEBUG", False) # noqa: FBT003 AI_CACHE = get_string(name="AI_CACHE", default="redis") diff --git a/poetry.lock b/poetry.lock index 36be86e..78f44a0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -684,6 +684,27 @@ Django = ">=4.2" daphne = ["daphne (>=4.0.0)"] tests = ["async-timeout", "coverage (>=4.5,<5.0)", "pytest", "pytest-asyncio", "pytest-django"] +[[package]] +name = "channels-redis" +version = "4.2.1" +description = "Redis-backed ASGI channel layer implementation" +optional = false +python-versions = ">=3.8" +files = [ + {file = "channels_redis-4.2.1-py3-none-any.whl", hash = "sha256:2ca33105b3a04b5a327a9c47dd762b546f30b76a0cd3f3f593a23d91d346b6f4"}, + {file = "channels_redis-4.2.1.tar.gz", hash = "sha256:8375e81493e684792efe6e6eca60ef3d7782ef76c6664057d2e5c31e80d636dd"}, +] + +[package.dependencies] +asgiref = ">=3.2.10,<4" +channels = "*" +msgpack = ">=1.0,<2.0" +redis = ">=4.6" + +[package.extras] +cryptography = ["cryptography (>=1.3.0)"] +tests = ["async-timeout", "cryptography (>=1.3.0)", "pytest", "pytest-asyncio", "pytest-timeout"] + [[package]] name = "charset-normalizer" version = "3.4.0" @@ -6224,4 +6245,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = "3.12.6" -content-hash = "3fc1d45c50905ef730ee7f95e06a112593509ec512e3479b1a126a6cfcf3a279" +content-hash = "3bda5fae5c354d0d77bd5648d66a5820595e5b629a48d0f17b73e8fdda4a72c3" diff --git a/pyproject.toml b/pyproject.toml index dc1a9c2..e63f242 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ uvicorn = {extras = ["standard"], version = "^0.32.1"} django-guardian = "^2.4.0" named-enum = "^1.4.0" ulid-py = "^0.2.0" +channels-redis = "^4.2.1" [tool.poetry.group.dev.dependencies] bpython = "^0.24"