Skip to content

Commit 6a64d6a

Browse files
taskingaijcjameszyao
authored andcommitted
test: add test for assistant and inference with dict
1 parent 81c824a commit 6a64d6a

File tree

11 files changed

+629
-372
lines changed

11 files changed

+629
-372
lines changed

Diff for: taskingai/client/models/entities/action.py

-4
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,7 @@ class Action(BaseModel):
3131
description: str = Field(..., min_length=1, max_length=512)
3232
url: str = Field(...)
3333
method: ActionMethod = Field(...)
34-
path_param_schema: Optional[Dict[str, ActionParam]] = Field(None)
35-
query_param_schema: Optional[Dict[str, ActionParam]] = Field(None)
3634
body_type: ActionBodyType = Field(...)
37-
body_param_schema: Optional[Dict[str, ActionParam]] = Field(None)
38-
function_def: ChatCompletionFunction = Field(...)
3935
openapi_schema: Dict[str, Any] = Field(...)
4036
authentication: ActionAuthentication = Field(...)
4137
updated_timestamp: int = Field(..., ge=0)

Diff for: taskingai/inference/chat_completion.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -59,21 +59,25 @@ def _validate_chat_completion_params(
5959
def _validate_message(msg: Union[SystemMessage, UserMessage, AssistantMessage, FunctionMessage, Dict[str, Any]]):
6060
if isinstance(msg, Dict):
6161
if msg["role"] == ChatCompletionRole.SYSTEM.value:
62+
msg.pop("role")
6263
return SystemMessage(**msg)
6364
elif msg["role"] == ChatCompletionRole.USER.value:
65+
msg.pop("role")
6466
return UserMessage(**msg)
6567
elif msg["role"] == ChatCompletionRole.ASSISTANT.value:
68+
msg.pop("role")
6669
return AssistantMessage(**msg)
6770
elif msg["role"] == ChatCompletionRole.FUNCTION.value:
71+
msg.pop("role")
6872
return FunctionMessage(**msg)
6973
else:
7074
raise ValueError("Invalid message role.")
7175

7276
elif (
73-
isinstance(msg, SystemMessage)
74-
or isinstance(msg, UserMessage)
75-
or isinstance(msg, AssistantMessage)
76-
or isinstance(msg, FunctionMessage)
77+
isinstance(msg, ChatCompletionSystemMessage)
78+
or isinstance(msg, ChatCompletionUserMessage)
79+
or isinstance(msg, ChatCompletionAssistantMessage)
80+
or isinstance(msg, ChatCompletionFunctionMessage)
7781
):
7882
return msg
7983

Diff for: test/common/utils.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,14 @@ def assume_assistant_result(assistant_dict: dict, res: dict):
135135
for key, value in assistant_dict.items():
136136
if key == 'system_prompt_template' and isinstance(value, str):
137137
pytest.assume(res[key] == [assistant_dict[key]])
138-
elif key in ["memory", "tool", "retrievals"]:
138+
elif key in ['retrieval_configs']:
139+
if isinstance(value, dict):
140+
pytest.assume(vars(res[key]) == assistant_dict[key])
141+
else:
142+
pytest.assume(res[key] == assistant_dict[key])
143+
elif key in ["memory", "tools", "retrievals"]:
139144
continue
140145
else:
141-
if key == 'retrieval_configs':
142-
res[key] = vars(res[key])
143146
pytest.assume(res[key] == assistant_dict[key])
144147

145148

Diff for: test/testcase/test_async/test_async_assistant.py

+49-23
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from taskingai.assistant import *
44
from taskingai.client.models import ToolRef, ToolType, RetrievalRef, RetrievalType, RetrievalConfig
5-
from taskingai.assistant.memory import AssistantNaiveMemory
5+
from taskingai.assistant.memory import AssistantNaiveMemory, AssistantZeroMemory
66
from test.config import Config
77
from test.common.logger import logger
88
from test.common.utils import list_to_dict
@@ -13,16 +13,6 @@
1313
@pytest.mark.test_async
1414
class TestAssistant(Base):
1515

16-
retrieval_configs_list = [
17-
{"method": "memory", "top_k": 2, "max_tokens": 4000},
18-
RetrievalConfig(
19-
method="memory",
20-
top_k=1,
21-
max_tokens=5000,
22-
23-
)
24-
]
25-
2616
@pytest.mark.run(order=51)
2717
@pytest.mark.asyncio
2818
async def test_a_create_assistant(self):
@@ -62,7 +52,11 @@ async def test_a_create_assistant(self):
6252
}
6353
for i in range(4):
6454
if i == 0:
55+
assistant_dict.update({"memory": {"type": "naive"}})
56+
assistant_dict.update({"retrievals": [{"type": "collection", "id": self.collection_id}]})
6557
assistant_dict.update({"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000}})
58+
assistant_dict.update({"tools": [{"type": "action", "id": self.action_id},
59+
{"type": "plugin", "id": "open_weather/get_hourly_forecast"}]})
6660
res = await a_create_assistant(**assistant_dict)
6761
res_dict = vars(res)
6862
logger.info(f'response_dict:{res_dict}, except_dict:{assistant_dict}')
@@ -106,22 +100,54 @@ async def test_a_get_assistant(self):
106100

107101
@pytest.mark.run(order=54)
108102
@pytest.mark.asyncio
109-
@pytest.mark.parametrize("retrieval_configs", retrieval_configs_list)
110-
async def test_a_update_assistant(self, retrieval_configs):
103+
async def test_a_update_assistant(self):
111104

112105
# Update an assistant.
113106

114-
name = "openai"
115-
description = "test for openai"
107+
update_data_list = [
108+
{
109+
"name": "openai",
110+
"description": "test for openai",
111+
"memory": AssistantZeroMemory(),
112+
"retrievals": [
113+
RetrievalRef(
114+
type=RetrievalType.COLLECTION,
115+
id=self.collection_id,
116+
),
117+
],
118+
"retrieval_configs": RetrievalConfig(
119+
method="memory",
120+
top_k=2,
121+
max_tokens=4000,
116122

117-
res = await a_update_assistant(assistant_id=self.assistant_id, name=name, description=description, retrieval_configs=retrieval_configs)
118-
res_dict = vars(res)
119-
pytest.assume(res_dict["name"] == name)
120-
pytest.assume(res_dict["description"] == description)
121-
if isinstance(retrieval_configs, dict):
122-
pytest.assume(vars(res_dict["retrieval_configs"]) == retrieval_configs)
123-
else:
124-
pytest.assume(res_dict["retrieval_configs"] == retrieval_configs)
123+
),
124+
"tools": [
125+
ToolRef(
126+
type=ToolType.ACTION,
127+
id=self.action_id,
128+
),
129+
ToolRef(
130+
type=ToolType.PLUGIN,
131+
id="open_weather/get_hourly_forecast",
132+
)
133+
]
134+
},
135+
{
136+
"name": "openai",
137+
"description": "test for openai",
138+
"memory": {"type": "naive"},
139+
"retrievals": [{"type": "collection", "id": self.collection_id}],
140+
"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000},
141+
"tools": [{"type": "action", "id": self.action_id},
142+
{"type": "plugin", "id": "open_weather/get_hourly_forecast"}]
143+
144+
}
145+
]
146+
for update_data in update_data_list:
147+
res = await a_update_assistant(assistant_id=self.assistant_id, **update_data)
148+
res_dict = vars(res)
149+
logger.info(f'response_dict:{res_dict}, except_dict:{update_data}')
150+
assume_assistant_result(update_data, res_dict)
125151

126152
@pytest.mark.run(order=66)
127153
@pytest.mark.asyncio

0 commit comments

Comments
 (0)