Skip to content

Commit

Permalink
Remove authentication app, prepare for keycloak, add unit tests for a…
Browse files Browse the repository at this point in the history
…i agents
  • Loading branch information
mbertrand committed Dec 18, 2024
1 parent 253f015 commit 863835e
Show file tree
Hide file tree
Showing 41 changed files with 481 additions and 827 deletions.
2 changes: 1 addition & 1 deletion ai_agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def get_completion(self, message: str, *, debug: bool = settings.AI_DEBUG) -> st
yield f"\n\n<!-- {self.get_comment_metadata()} -->\n\n"


class SearchAgent(BaseChatAgent):
class RecommendationAgent(BaseChatAgent):
"""Service class for the AI search function agent"""

JOB_ID = "SEARCH_JOB"
Expand Down
31 changes: 10 additions & 21 deletions ai_agents/agents_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@

import pytest
from django.conf import settings
from llama_index.core.base.llms.types import MessageRole
from llama_index.core.constants import DEFAULT_TEMPERATURE

from ai_agents.agents import SearchAgent
from ai_agents.factories import ChatMessageFactory
from ai_agents.agents import RecommendationAgent
from main.test_utils import assert_json_equal


Expand All @@ -21,15 +19,6 @@ def ai_settings(settings):
return settings


@pytest.fixture
def chat_history():
"""Return one round trip chat history for testing."""
return [
ChatMessageFactory(role=MessageRole.USER),
ChatMessageFactory(role=MessageRole.ASSISTANT),
]


@pytest.fixture
def search_results():
"""Return search results for testing."""
Expand All @@ -47,10 +36,10 @@ def search_results():
],
)
def test_search_agent_service_initialization_defaults(model, temperature, instructions):
"""Test the SearchAgent class instantiation."""
"""Test the RecommendationAgent class instantiation."""
name = "My search agent"

search_agent = SearchAgent(
search_agent = RecommendationAgent(
"user",
name=name,
model=model,
Expand All @@ -71,10 +60,10 @@ def test_search_agent_service_initialization_defaults(model, temperature, instru


def test_clear_chat_history(client, user, chat_history):
"""Test that the SearchAgent clears chat_history."""
search_agent = SearchAgent(user.username)
"""Test that the RecommendationAgent clears chat_history."""
search_agent = RecommendationAgent(user.username)
search_agent.agent.chat_history.extend(chat_history)
assert len(search_agent.agent.chat_history) == 2
assert len(search_agent.agent.chat_history) == 4
search_agent.clear_chat_history()
assert search_agent.agent.chat_history == []

Expand Down Expand Up @@ -104,7 +93,7 @@ def test_search_agent_tool(settings, mocker, search_results):
"ai_agents.agents.requests.get",
return_value=mocker.Mock(json=mocker.Mock(return_value=search_results)),
)
search_agent = SearchAgent("anonymous", name="test agent")
search_agent = RecommendationAgent("anonymous", name="test agent")
search_parameters = {
"q": "physics",
"resource_type": ["course", "program"],
Expand All @@ -124,21 +113,21 @@ def test_search_agent_tool(settings, mocker, search_results):
@pytest.mark.django_db
@pytest.mark.parametrize("debug", [True, False])
def test_get_completion(settings, mocker, debug, search_results):
"""Test that the SearchAgent get_completion method returns expected values."""
"""Test that the RecommendationAgent get_completion method returns expected values."""
settings.AI_DEBUG = debug
metadata = {
"metadata": {
"search_parameters": {"q": "physics"},
"search_results": search_results.get("results"),
"system_prompt": SearchAgent.INSTRUCTIONS,
"system_prompt": RecommendationAgent.INSTRUCTIONS,
}
}
expected_return_value = [b"Here ", b"are ", b"some ", b"results"]
mocker.patch(
"ai_agents.agents.OpenAIAgent.stream_chat",
return_value=mocker.Mock(response_gen=iter(expected_return_value)),
)
search_agent = SearchAgent("anonymous", name="test agent")
search_agent = RecommendationAgent("anonymous", name="test agent")
search_agent.search_parameters = metadata["metadata"]["search_parameters"]
search_agent.search_results = metadata["metadata"]["search_results"]
search_agent.instructions = metadata["metadata"]["system_prompt"]
Expand Down
17 changes: 11 additions & 6 deletions ai_agents/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import pytest
from llama_cloud import MessageRole

from ai_agents.factories import ChatMessageFactory


@pytest.fixture
def mock_search_agent_service(mocker):
"""Mock the SearchAgentService class."""
return mocker.patch(
"ai_agents.views.SearchAgentService",
autospec=True,
)
def chat_history():
"""Return one round trip chat history for testing."""
return [
ChatMessageFactory.create(role=MessageRole.USER),
ChatMessageFactory.create(role=MessageRole.ASSISTANT),
ChatMessageFactory.create(role=MessageRole.USER),
ChatMessageFactory.create(role=MessageRole.ASSISTANT),
]
17 changes: 11 additions & 6 deletions ai_agents/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,28 @@
from channels.generic.websocket import AsyncWebsocketConsumer
from llama_index.core.base.llms.types import ChatMessage

from ai_agents.agents import RecommendationAgent
from ai_agents.serializers import ChatRequestSerializer

log = logging.getLogger(__name__)


class RecommendationAgentConsumer(AsyncWebsocketConsumer):
"""
Async websocket consumer for the recommendation agent.
"""

async def connect(self):
"""Connect to the websocket and initialize the AI agent."""
user = self.scope.get("user", None)
self.username = user.username if user else "anonymous"
log.info("Username is %s", self.username)
from ai_agents.agents import SearchAgent
self.user_id = user.username if user else "anonymous"
log.info("Username is %s", self.user_id)

self.agent = SearchAgent(self.username)
self.agent = RecommendationAgent(self.user_id)
await super().connect()

async def receive(self, text_data: str) -> str:
"""Send the message to the AI agent and return its response."""
from ai_agents.serializers import ChatRequestSerializer

try:
text_data_json = json.loads(text_data)
Expand All @@ -33,7 +38,7 @@ async def receive(self, text_data: str) -> str:
model = serializer.validated_data.pop("model", None)

if clear_history:
self.agent.agent.chat_history.clear()
self.agent.clear_chat_history()
if model:
self.agent.agent.agent_worker._llm.model = model # noqa: SLF001
if temperature:
Expand Down
125 changes: 125 additions & 0 deletions ai_agents/consumers_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
"""Tests for ai_agents consumers"""

import json
from random import randint

import pytest
from llama_cloud import MessageRole
from llama_index.core.constants import DEFAULT_TEMPERATURE

from ai_agents import consumers
from ai_agents.agents import RecommendationAgent
from ai_agents.factories import ChatMessageFactory
from main.factories import UserFactory


@pytest.fixture(autouse=True)
def mock_connect(mocker):
"""Mock the AsyncWebsocketConsumer connect function"""
return mocker.patch(
"ai_agents.consumers.AsyncWebsocketConsumer.connect",
new_callable=mocker.AsyncMock,
)


@pytest.fixture(autouse=True)
def mock_send(mocker):
"""Mock the AsyncWebsocketConsumer connect function"""
return mocker.patch(
"ai_agents.consumers.AsyncWebsocketConsumer.send", new_callable=mocker.AsyncMock
)


@pytest.fixture
def agent_user():
"""Return a user for the agent."""
return UserFactory.build(
username=f"test_user_{randint(1, 1000)}" # noqa: S311
)


@pytest.fixture
def recommendation_consumer(agent_user):
"""Return a recommendation consumer."""
consumer = consumers.RecommendationAgentConsumer()
consumer.scope = {"user": agent_user}
return consumer


async def test_recommend_agent_connect(
recommendation_consumer, agent_user, mock_connect
):
"""Test the connect function of the recommendation agent."""
await recommendation_consumer.connect()

assert mock_connect.call_count == 1
assert recommendation_consumer.user_id == agent_user.username
assert recommendation_consumer.agent.user_id == agent_user.username


@pytest.mark.parametrize(
("message", "temperature", "instructions", "model"),
[
("hello", 0.7, "Answer this question as best you can", "gpt-3.5-turbo"),
("hello", 0.7, "", "gpt-3.5-turbo"),
("hello", 0.6, None, "gpt-4-turbo"),
("hello", 0.4, None, "gpt-4o"),
("hello", 0.4, None, ""),
("hello", None, None, None),
],
)
async def test_recommend_agent_receive( # noqa: PLR0913
settings,
mocker,
mock_send,
recommendation_consumer,
message,
temperature,
instructions,
model,
):
"""Test the receive function of the recommendation agent."""
response = ChatMessageFactory.create(role=MessageRole.ASSISTANT)
mock_completion = mocker.patch(
"ai_agents.agents.RecommendationAgent.get_completion",
side_effect=[response.content.split(" ")],
)
data = {
"message": message,
}
if temperature:
data["temperature"] = temperature
if instructions is not None:
data["instructions"] = instructions
if model is not None:
data["model"] = model
await recommendation_consumer.connect()
await recommendation_consumer.receive(json.dumps(data))

assert recommendation_consumer.agent.user_id.startswith("test_user")
assert recommendation_consumer.agent.agent.agent_worker._llm.temperature == ( # noqa: SLF001
temperature if temperature else DEFAULT_TEMPERATURE
)
assert recommendation_consumer.agent.agent.agent_worker.prefix_messages[
0
].content == (instructions if instructions else RecommendationAgent.INSTRUCTIONS)
assert recommendation_consumer.agent.agent.agent_worker._llm.model == ( # noqa: SLF001
model if model else settings.AI_MODEL
)

mock_completion.assert_called_once_with(message)
assert mock_send.call_count == len(response.content.split(" ")) + 1
mock_send.assert_any_call(text_data="!endResponse")


@pytest.mark.parametrize("clear_history", [True, False])
async def test_clear_history(mocker, clear_history, recommendation_consumer):
"""Test the clear history function of the recommendation agent."""
mock_clear = mocker.patch(
"ai_agents.consumers.RecommendationAgent.clear_chat_history"
)
await recommendation_consumer.connect()
await recommendation_consumer.receive(
json.dumps({"clear_history": clear_history, "message": "hello"})
)
assert mock_clear.call_count == (1 if clear_history else 0)
14 changes: 11 additions & 3 deletions ai_agents/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@


class AIProxy(ABC):
"""Abstract base helper class for an AI proxy/gateway."""
"""
Abstract base helper class for an AI proxy/gateway.
"""

REQUIRED_SETTINGS = []
REQUIRED_SETTINGS = ["AI_PROXY_URL", "AI_PROXY_AUTH_TOKEN", "AI_PROXY_CLASS"]

def __init__(self):
"""Raise an error if required settings are missing."""
Expand Down Expand Up @@ -48,12 +50,18 @@ class LiteLLMProxy(AIProxy):
REQUIRED_SETTINGS = ("AI_PROXY_URL", "AI_PROXY_AUTH_TOKEN")

def get_api_kwargs(self) -> dict:
"""
Return the api kwargs required to connect to the proxy.
"""
return {
"api_base": settings.AI_PROXY_URL,
"api_key": settings.AI_PROXY_AUTH_TOKEN,
}

def get_additional_kwargs(self, service: BaseChatAgent) -> dict:
"""
Return any additional kwargs that should be sent to the proxy.
"""
return {
"user": service.user_id,
"store": True,
Expand All @@ -67,7 +75,7 @@ def get_additional_kwargs(self, service: BaseChatAgent) -> dict:
},
}

def create_proxy_user(self, user_id, endpoint="new") -> None:
def create_proxy_user(self, user_id: str, endpoint: str = "new") -> None:
"""
Set the rate limit for the user by creating a LiteLLM customer account.
Anonymous users will share the same rate limit.
Expand Down
6 changes: 6 additions & 0 deletions ai_agents/proxy_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Tests for ai_agents proxy functionality"""


def test_litellm_create_user():
"""Test that correct api calls are made to create a LitelLM proxy user"""
assert True
4 changes: 2 additions & 2 deletions ai_agents/routing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from django.urls import re_path

from . import consumers
from ai_agents import consumers

websocket_urlpatterns = [
# other websocket URLs here
# websocket URLs go here
re_path(
r"ws/recommendation_agent/",
consumers.RecommendationAgentConsumer.as_asgi(),
Expand Down
16 changes: 12 additions & 4 deletions ai_agents/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,16 @@
class ChatRequestSerializer(serializers.Serializer):
"""DRF serializer for chatbot requests"""

message = serializers.CharField(allow_blank=False)
model = serializers.CharField(default=settings.AI_MODEL, required=False)
temperature = serializers.FloatField(min_value=0.0, max_value=1.0, required=False)
message = serializers.CharField(required=True, allow_blank=False)
model = serializers.CharField(
default=settings.AI_MODEL, required=False, allow_blank=True
)
temperature = serializers.FloatField(
min_value=0.0,
max_value=1.0,
required=False,
)
instructions = serializers.CharField(required=False, allow_blank=True)
clear_history = serializers.BooleanField(default=False)
clear_history = serializers.BooleanField(
default=False,
)
17 changes: 15 additions & 2 deletions ai_agents/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,22 @@
import logging
from enum import Enum

from named_enum import ExtendedEnum

log = logging.getLogger(__name__)


def enum_zip(label: str, enum: Enum) -> type[Enum]:
"""Create a new Enum from a tuple of Enum names"""
def enum_zip(label: str, enum: ExtendedEnum) -> type[Enum]:
"""
Create a new Enum with both name and value equal to
the names of the given ExtendedEnum.
Args:
label: The label for the new Enum
enum: The Enum to use as a basis for the new Enum
Returns:
A new Enum with the names of the given Enum as both name and value
"""
return Enum(label, dict(zip(enum.names(), enum.names())))
Loading

0 comments on commit 863835e

Please sign in to comment.