Skip to content

Commit

Permalink
More cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
mbertrand committed Dec 18, 2024
1 parent 1b923c8 commit 253f015
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ jobs:
export MEDIA_ROOT="$(mktemp -d)"
./scripts/test/python_tests.sh
env:
DATABASE_URL: postgres://postgres:postgres@localhost:5432/postgres # pragma: allowlist secret
DATABASE_URL: postgres://postgres:postgres@localhost:5433/postgres # pragma: allowlist secret
MITOL_SECURE_SSL_REDIRECT: "False"
MITOL_DB_DISABLE_SSL: "True"
MITOL_FEATURES_DEFAULT: "True"
Expand Down
7 changes: 6 additions & 1 deletion ai_agents/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,21 @@ class BaseChatAgent(ABC):

def __init__(
self,
user_id: str,
*,
name: str = "AI Chat Agent",
model: Optional[str] = None,
temperature: Optional[float] = None,
instructions: Optional[str] = None,
):
"""Initialize the AI chat agent service"""
self.user_id = user_id
self.assistant_name = name
self.ai = settings.AI_MODEL_API
self.model = model or settings.AI_MODEL
self.temperature = temperature or DEFAULT_TEMPERATURE
self.instructions = instructions or self.INSTRUCTIONS
if settings.AI_PROXY_CLASS:
if settings.AI_PROXY_CLASS and settings.AI_PROXY_URL:
self.proxy = import_string(f"ai_agents.proxy.{settings.AI_PROXY_CLASS}")()
else:
self.proxy = None
Expand Down Expand Up @@ -275,6 +277,7 @@ class SearchToolSchema(pydantic.BaseModel):

def __init__(
self,
user_id: str,
*,
name: Optional[str] = "Learning Resource Search AI Assistant",
model: Optional[str] = None,
Expand All @@ -283,13 +286,15 @@ def __init__(
):
"""Initialize the AI search agent service"""
super().__init__(
user_id,
name=name,
model=model or settings.AI_MODEL,
temperature=temperature,
instructions=instructions,
)
self.search_parameters = []
self.search_results = []

self.agent = self.create_agent()
self.create_agent()

Expand Down
7 changes: 4 additions & 3 deletions ai_agents/agents_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def test_search_agent_service_initialization_defaults(model, temperature, instru
name = "My search agent"

search_agent = SearchAgent(
"user",
name=name,
model=model,
temperature=temperature,
Expand All @@ -71,7 +72,7 @@ def test_search_agent_service_initialization_defaults(model, temperature, instru

def test_clear_chat_history(client, user, chat_history):
"""Test that the SearchAgent clears chat_history."""
search_agent = SearchAgent()
search_agent = SearchAgent(user.username)
search_agent.agent.chat_history.extend(chat_history)
assert len(search_agent.agent.chat_history) == 2
search_agent.clear_chat_history()
Expand Down Expand Up @@ -103,7 +104,7 @@ def test_search_agent_tool(settings, mocker, search_results):
"ai_agents.agents.requests.get",
return_value=mocker.Mock(json=mocker.Mock(return_value=search_results)),
)
search_agent = SearchAgent(name="test agent")
search_agent = SearchAgent("anonymous", name="test agent")
search_parameters = {
"q": "physics",
"resource_type": ["course", "program"],
Expand Down Expand Up @@ -137,7 +138,7 @@ def test_get_completion(settings, mocker, debug, search_results):
"ai_agents.agents.OpenAIAgent.stream_chat",
return_value=mocker.Mock(response_gen=iter(expected_return_value)),
)
search_agent = SearchAgent(name="test agent")
search_agent = SearchAgent("anonymous", name="test agent")
search_agent.search_parameters = metadata["metadata"]["search_parameters"]
search_agent.search_results = metadata["metadata"]["search_results"]
search_agent.instructions = metadata["metadata"]["system_prompt"]
Expand Down
2 changes: 1 addition & 1 deletion ai_agents/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ async def connect(self):
log.info("Username is %s", self.username)
from ai_agents.agents import SearchAgent

self.agent = SearchAgent()
self.agent = SearchAgent(self.username)
await super().connect()

async def receive(self, text_data: str) -> str:
Expand Down

0 comments on commit 253f015

Please sign in to comment.