From 253f015cd25b772b74abb6f78ba906b3c8343bc5 Mon Sep 17 00:00:00 2001 From: Matt Bertrand Date: Tue, 17 Dec 2024 22:08:42 -0500 Subject: [PATCH] More cleanup --- .github/workflows/ci.yml | 2 +- ai_agents/agents.py | 7 ++++++- ai_agents/agents_test.py | 7 ++++--- ai_agents/consumers.py | 2 +- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c5c270b..cca8340 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -61,7 +61,7 @@ jobs: export MEDIA_ROOT="$(mktemp -d)" ./scripts/test/python_tests.sh env: - DATABASE_URL: postgres://postgres:postgres@localhost:5432/postgres # pragma: allowlist secret + DATABASE_URL: postgres://postgres:postgres@localhost:5433/postgres # pragma: allowlist secret MITOL_SECURE_SSL_REDIRECT: "False" MITOL_DB_DISABLE_SSL: "True" MITOL_FEATURES_DEFAULT: "True" diff --git a/ai_agents/agents.py b/ai_agents/agents.py index 587204c..9aca631 100644 --- a/ai_agents/agents.py +++ b/ai_agents/agents.py @@ -48,6 +48,7 @@ class BaseChatAgent(ABC): def __init__( self, + user_id: str, *, name: str = "AI Chat Agent", model: Optional[str] = None, @@ -55,12 +56,13 @@ def __init__( instructions: Optional[str] = None, ): """Initialize the AI chat agent service""" + self.user_id = user_id self.assistant_name = name self.ai = settings.AI_MODEL_API self.model = model or settings.AI_MODEL self.temperature = temperature or DEFAULT_TEMPERATURE self.instructions = instructions or self.INSTRUCTIONS - if settings.AI_PROXY_CLASS: + if settings.AI_PROXY_CLASS and settings.AI_PROXY_URL: self.proxy = import_string(f"ai_agents.proxy.{settings.AI_PROXY_CLASS}")() else: self.proxy = None @@ -275,6 +277,7 @@ class SearchToolSchema(pydantic.BaseModel): def __init__( self, + user_id: str, *, name: Optional[str] = "Learning Resource Search AI Assistant", model: Optional[str] = None, @@ -283,6 +286,7 @@ def __init__( ): """Initialize the AI search agent service""" super().__init__( + user_id, name=name, model=model or settings.AI_MODEL, temperature=temperature, @@ -290,6 +294,7 @@ def __init__( ) self.search_parameters = [] self.search_results = [] + self.agent = self.create_agent() self.create_agent() diff --git a/ai_agents/agents_test.py b/ai_agents/agents_test.py index 515de51..c49507c 100644 --- a/ai_agents/agents_test.py +++ b/ai_agents/agents_test.py @@ -51,6 +51,7 @@ def test_search_agent_service_initialization_defaults(model, temperature, instru name = "My search agent" search_agent = SearchAgent( + "user", name=name, model=model, temperature=temperature, @@ -71,7 +72,7 @@ def test_search_agent_service_initialization_defaults(model, temperature, instru def test_clear_chat_history(client, user, chat_history): """Test that the SearchAgent clears chat_history.""" - search_agent = SearchAgent() + search_agent = SearchAgent(user.username) search_agent.agent.chat_history.extend(chat_history) assert len(search_agent.agent.chat_history) == 2 search_agent.clear_chat_history() @@ -103,7 +104,7 @@ def test_search_agent_tool(settings, mocker, search_results): "ai_agents.agents.requests.get", return_value=mocker.Mock(json=mocker.Mock(return_value=search_results)), ) - search_agent = SearchAgent(name="test agent") + search_agent = SearchAgent("anonymous", name="test agent") search_parameters = { "q": "physics", "resource_type": ["course", "program"], @@ -137,7 +138,7 @@ def test_get_completion(settings, mocker, debug, search_results): "ai_agents.agents.OpenAIAgent.stream_chat", return_value=mocker.Mock(response_gen=iter(expected_return_value)), ) - search_agent = SearchAgent(name="test agent") + search_agent = SearchAgent("anonymous", name="test agent") search_agent.search_parameters = metadata["metadata"]["search_parameters"] search_agent.search_results = metadata["metadata"]["search_results"] search_agent.instructions = metadata["metadata"]["system_prompt"] diff --git a/ai_agents/consumers.py b/ai_agents/consumers.py index e078ddb..5afe94b 100644 --- a/ai_agents/consumers.py +++ b/ai_agents/consumers.py @@ -15,7 +15,7 @@ async def connect(self): log.info("Username is %s", self.username) from ai_agents.agents import SearchAgent - self.agent = SearchAgent() + self.agent = SearchAgent(self.username) await super().connect() async def receive(self, text_data: str) -> str: