|
2 | 2 |
|
3 | 3 | from taskingai.assistant import *
|
4 | 4 | from taskingai.client.models import ToolRef, ToolType, RetrievalRef, RetrievalType, RetrievalConfig
|
5 |
| -from taskingai.assistant.memory import AssistantNaiveMemory |
| 5 | +from taskingai.assistant.memory import AssistantNaiveMemory, AssistantZeroMemory |
6 | 6 | from test.config import Config
|
7 | 7 | from test.common.logger import logger
|
8 | 8 | from test.common.utils import list_to_dict
|
|
13 | 13 | @pytest.mark.test_async
|
14 | 14 | class TestAssistant(Base):
|
15 | 15 |
|
16 |
| - retrieval_configs_list = [ |
17 |
| - {"method": "memory", "top_k": 2, "max_tokens": 4000}, |
18 |
| - RetrievalConfig( |
19 |
| - method="memory", |
20 |
| - top_k=1, |
21 |
| - max_tokens=5000, |
22 |
| - |
23 |
| - ) |
24 |
| - ] |
25 |
| - |
26 | 16 | @pytest.mark.run(order=51)
|
27 | 17 | @pytest.mark.asyncio
|
28 | 18 | async def test_a_create_assistant(self):
|
@@ -62,7 +52,11 @@ async def test_a_create_assistant(self):
|
62 | 52 | }
|
63 | 53 | for i in range(4):
|
64 | 54 | if i == 0:
|
| 55 | + assistant_dict.update({"memory": {"type": "naive"}}) |
| 56 | + assistant_dict.update({"retrievals": [{"type": "collection", "id": self.collection_id}]}) |
65 | 57 | assistant_dict.update({"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000}})
|
| 58 | + assistant_dict.update({"tools": [{"type": "action", "id": self.action_id}, |
| 59 | + {"type": "plugin", "id": "open_weather/get_hourly_forecast"}]}) |
66 | 60 | res = await a_create_assistant(**assistant_dict)
|
67 | 61 | res_dict = vars(res)
|
68 | 62 | logger.info(f'response_dict:{res_dict}, except_dict:{assistant_dict}')
|
@@ -106,22 +100,54 @@ async def test_a_get_assistant(self):
|
106 | 100 |
|
107 | 101 | @pytest.mark.run(order=54)
|
108 | 102 | @pytest.mark.asyncio
|
109 |
| - @pytest.mark.parametrize("retrieval_configs", retrieval_configs_list) |
110 |
| - async def test_a_update_assistant(self, retrieval_configs): |
| 103 | + async def test_a_update_assistant(self): |
111 | 104 |
|
112 | 105 | # Update an assistant.
|
113 | 106 |
|
114 |
| - name = "openai" |
115 |
| - description = "test for openai" |
| 107 | + update_data_list = [ |
| 108 | + { |
| 109 | + "name": "openai", |
| 110 | + "description": "test for openai", |
| 111 | + "memory": AssistantZeroMemory(), |
| 112 | + "retrievals": [ |
| 113 | + RetrievalRef( |
| 114 | + type=RetrievalType.COLLECTION, |
| 115 | + id=self.collection_id, |
| 116 | + ), |
| 117 | + ], |
| 118 | + "retrieval_configs": RetrievalConfig( |
| 119 | + method="memory", |
| 120 | + top_k=2, |
| 121 | + max_tokens=4000, |
116 | 122 |
|
117 |
| - res = await a_update_assistant(assistant_id=self.assistant_id, name=name, description=description, retrieval_configs=retrieval_configs) |
118 |
| - res_dict = vars(res) |
119 |
| - pytest.assume(res_dict["name"] == name) |
120 |
| - pytest.assume(res_dict["description"] == description) |
121 |
| - if isinstance(retrieval_configs, dict): |
122 |
| - pytest.assume(vars(res_dict["retrieval_configs"]) == retrieval_configs) |
123 |
| - else: |
124 |
| - pytest.assume(res_dict["retrieval_configs"] == retrieval_configs) |
| 123 | + ), |
| 124 | + "tools": [ |
| 125 | + ToolRef( |
| 126 | + type=ToolType.ACTION, |
| 127 | + id=self.action_id, |
| 128 | + ), |
| 129 | + ToolRef( |
| 130 | + type=ToolType.PLUGIN, |
| 131 | + id="open_weather/get_hourly_forecast", |
| 132 | + ) |
| 133 | + ] |
| 134 | + }, |
| 135 | + { |
| 136 | + "name": "openai", |
| 137 | + "description": "test for openai", |
| 138 | + "memory": {"type": "naive"}, |
| 139 | + "retrievals": [{"type": "collection", "id": self.collection_id}], |
| 140 | + "retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000}, |
| 141 | + "tools": [{"type": "action", "id": self.action_id}, |
| 142 | + {"type": "plugin", "id": "open_weather/get_hourly_forecast"}] |
| 143 | + |
| 144 | + } |
| 145 | + ] |
| 146 | + for update_data in update_data_list: |
| 147 | + res = await a_update_assistant(assistant_id=self.assistant_id, **update_data) |
| 148 | + res_dict = vars(res) |
| 149 | + logger.info(f'response_dict:{res_dict}, except_dict:{update_data}') |
| 150 | + assume_assistant_result(update_data, res_dict) |
125 | 151 |
|
126 | 152 | @pytest.mark.run(order=66)
|
127 | 153 | @pytest.mark.asyncio
|
|
0 commit comments