Skip to content

Commit

Permalink
added unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
tanweersalah committed Feb 20, 2025
1 parent dd397df commit 174610b
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 20 deletions.
5 changes: 0 additions & 5 deletions src/agents/common/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ def _subtask_selector_node(self, state: BaseAgentState) -> dict[str, Any]:
)
],
IS_LAST_STEP: True,
"is_last": True,
}

async def _invoke_chain(self, state: BaseAgentState, config: RunnableConfig) -> Any:
Expand All @@ -144,10 +143,6 @@ async def _model_node(
self, state: BaseAgentState, config: RunnableConfig
) -> dict[str, Any]:
try:
if state.my_task and state.my_task.assigned_to == "KubernetesAgent":
raise Exception(
"This is test exception ... need to be removed before code push"
)
response = await self._invoke_chain(state, config)
except Exception as e:
error_message = f"An error occurred while processing the request: {e}"
Expand Down
7 changes: 5 additions & 2 deletions src/agents/common/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,14 @@ def completed(self) -> bool:
"""Check if the task is completed."""
return self.status == SubTaskStatus.COMPLETED

def pending(self) -> bool:
def is_pending(self) -> bool:
"""Check if the task is pending."""
return self.status == SubTaskStatus.PENDING

def is_error(self) -> bool:
"""Check if the task is error status."""
return self.status == SubTaskStatus.ERROR


# After upgrading generative-ai-hub-sdk we can message that use pydantic v2
# Currently, we are using pydantic v1.
Expand Down Expand Up @@ -138,7 +142,6 @@ class BaseAgentState(BaseModel):
messages: Annotated[Sequence[BaseMessage], add_messages]
subtasks: list[SubTask] | None = []
k8s_client: IK8sClient
is_last: bool = False

# Subgraph private fields
agent_messages: Annotated[Sequence[BaseMessage], add_messages]
Expand Down
22 changes: 18 additions & 4 deletions src/agents/supervisor/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ def decide_route_or_exit(state: SupervisorState) -> Literal[ROUTER, END]: # typ
def decide_entry_point(state: SupervisorState) -> Literal[PLANNER, ROUTER, FINALIZER]: # type: ignore
"""When entering the supervisor subgraph, decide the entry point: plan, route, or finalize."""

# if all subtasks are completed, finalize the response
if state.subtasks and all(not subtask.pending() for subtask in state.subtasks):
logger.debug("Finalizing as all subtasks are completed.")
# if no subtasks is pending, finalize the response
if state.subtasks and all(not subtask.is_pending() for subtask in state.subtasks):
logger.debug("Routing to Finilizer as no subtasks is pending.")
return FINALIZER

# if subtasks exists but not all are completed, router delegates to the next agent
Expand Down Expand Up @@ -125,7 +125,7 @@ def agent_node(self) -> CompiledGraph:
def _route(self, state: SupervisorState) -> dict[str, Any]:
"""Router node. Routes the conversation to the next agent."""
for subtask in state.subtasks:
if subtask.pending():
if subtask.is_pending():
next_agent = subtask.assigned_to
return {
"next": next_agent,
Expand Down Expand Up @@ -223,6 +223,20 @@ def _final_response_chain(self, state: SupervisorState) -> RunnableSequence:
async def _generate_final_response(self, state: SupervisorState) -> dict[str, Any]:
"""Generate the final response."""

# If all required agents failed: tell user that we can't give them response due to agent failure
if state.subtasks and all(subtask.is_error() for subtask in state.subtasks):
return {
MESSAGES: [
AIMessage(
content="We're unable to provide a response at this time due to agent failure. "
"Our team has been notified and is working to resolve the issue. "
"Please try again or reach out to our support team for further assistance.",
name=FINALIZER,
)
],
NEXT: END,
}

final_response_chain = self._final_response_chain(state)

final_response = await ainvoke_chain(
Expand Down
1 change: 1 addition & 0 deletions temp_test_run_data.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"testCases": [], "conversationalTestCases": [], "metricsScores": [], "runDuration": 0.0}
79 changes: 70 additions & 9 deletions tests/unit/agents/supervisor/test_supervisor_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def supervisor_agent(self, mock_models):
description="Task 2",
task_title="Task 2",
assigned_to=KYMA_AGENT,
status="in_progress",
status="pending",
)
],
[AIMessage(content="Fake message")],
Expand All @@ -87,7 +87,7 @@ def supervisor_agent(self, mock_models):
description="Task 2",
task_title="Task 2",
assigned_to=KYMA_AGENT,
status="in_progress",
status="pending",
)
],
None,
Expand Down Expand Up @@ -139,7 +139,7 @@ def test_agent_route(

@pytest.mark.asyncio
@pytest.mark.parametrize(
"description, input_query, conversation_messages, final_response_content, expected_output, expected_error",
"description, input_query, conversation_messages, subtasks, final_response_content, expected_output, expected_error",
[
(
"Generates final response successfully",
Expand All @@ -155,6 +155,14 @@ def test_agent_route(
name="KubernetesAgent",
),
],
[
SubTask(
description="Explain Kyma function deployment",
task_title="Explain Kyma function deployment",
assigned_to=KYMA_AGENT,
status="completed",
),
],
"To deploy a Kyma function, follow these steps: "
"1. Create a function file. "
"2. Use the Kyma CLI to deploy. "
Expand Down Expand Up @@ -183,11 +191,57 @@ def test_agent_route(
name="KubernetesAgent",
),
],
[
SubTask(
description="Explain Kubernetes",
task_title="Explain Kubernetes",
assigned_to=K8S_AGENT,
status="completed",
),
],
"",
{
"messages": [AIMessage(content="", name="Finalizer")],
"next": "__end__",
},
"",
),
(
"Do not generate final response as all subtasks failed",
"What is Kubernetes? and what is KYMA",
[
HumanMessage(content="What is Kubernetes?"),
AIMessage(
content="Kubernetes is a container orchestration platform.",
name="KubernetesAgent",
),
],
[
SubTask(
description="Explain Kubernetes",
task_title="Explain Kubernetes",
assigned_to=K8S_AGENT,
status="error",
),
SubTask(
description="Explain Kyma",
task_title="Explain Kyma",
assigned_to=KYMA_AGENT,
status="error",
),
],
None, # this content should be handled by finalizer itself
{
"messages": [
AIMessage(
content="We're unable to provide a response at this time due to agent failure."
"Our team has been notified and is working to resolve the issue."
"Please try again or reach out to our support team for further assistance.",
name="Finalizer",
)
],
"next": "__end__",
},
None,
),
],
Expand All @@ -198,15 +252,19 @@ async def test_agent_generate_final_response(
description,
input_query,
conversation_messages,
subtasks,
final_response_content,
expected_output,
expected_error,
):
# Given
state = SupervisorState(messages=conversation_messages)
state = SupervisorState(messages=conversation_messages, subtasks=subtasks)

mock_final_response_chain = AsyncMock()
mock_final_response_chain.ainvoke.return_value.content = final_response_content
if final_response_content:
mock_final_response_chain.ainvoke.return_value.content = (
final_response_content
)

with patch.object(
supervisor_agent,
Expand All @@ -219,9 +277,10 @@ async def test_agent_generate_final_response(
# Then
assert result == expected_output

mock_final_response_chain.ainvoke.assert_called_once_with(
config=None, input={"messages": conversation_messages}
)
if final_response_content:
mock_final_response_chain.ainvoke.assert_called_once_with(
config=None, input={"messages": conversation_messages}
)

@pytest.mark.asyncio
@pytest.mark.parametrize(
Expand Down Expand Up @@ -324,7 +383,9 @@ async def test_agent_plan(
else:
mock_invoke_planner.return_value = Plan.parse_raw(mock_plan_content)

state = SupervisorState(messages=[HumanMessage(content=input_query)])
state = SupervisorState(
messages=[HumanMessage(content=input_query)], subtasks=[]
)
result = await supervisor_agent._plan(state)

assert result == expected_output
Expand Down

0 comments on commit 174610b

Please sign in to comment.