|
2 | 2 |
|
3 | 3 | from taskingai.assistant import * |
4 | 4 | from taskingai.client.models import ToolRef, ToolType, RetrievalRef, RetrievalType, RetrievalConfig |
5 | | -from taskingai.assistant.memory import AssistantNaiveMemory |
| 5 | +from taskingai.assistant.memory import AssistantNaiveMemory, AssistantZeroMemory |
6 | 6 | from test.config import Config |
7 | 7 | from test.common.logger import logger |
8 | 8 | from test.common.utils import list_to_dict |
|
13 | 13 | @pytest.mark.test_async |
14 | 14 | class TestAssistant(Base): |
15 | 15 |
|
16 | | - retrieval_configs_list = [ |
17 | | - {"method": "memory", "top_k": 2, "max_tokens": 4000}, |
18 | | - RetrievalConfig( |
19 | | - method="memory", |
20 | | - top_k=1, |
21 | | - max_tokens=5000, |
22 | | - |
23 | | - ) |
24 | | - ] |
25 | | - |
26 | 16 | @pytest.mark.run(order=51) |
27 | 17 | @pytest.mark.asyncio |
28 | 18 | async def test_a_create_assistant(self): |
@@ -62,7 +52,11 @@ async def test_a_create_assistant(self): |
62 | 52 | } |
63 | 53 | for i in range(4): |
64 | 54 | if i == 0: |
| 55 | + assistant_dict.update({"memory": {"type": "naive"}}) |
| 56 | + assistant_dict.update({"retrievals": [{"type": "collection", "id": self.collection_id}]}) |
65 | 57 | assistant_dict.update({"retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000}}) |
| 58 | + assistant_dict.update({"tools": [{"type": "action", "id": self.action_id}, |
| 59 | + {"type": "plugin", "id": "open_weather/get_hourly_forecast"}]}) |
66 | 60 | res = await a_create_assistant(**assistant_dict) |
67 | 61 | res_dict = vars(res) |
68 | 62 | logger.info(f'response_dict:{res_dict}, except_dict:{assistant_dict}') |
@@ -106,22 +100,54 @@ async def test_a_get_assistant(self): |
106 | 100 |
|
107 | 101 | @pytest.mark.run(order=54) |
108 | 102 | @pytest.mark.asyncio |
109 | | - @pytest.mark.parametrize("retrieval_configs", retrieval_configs_list) |
110 | | - async def test_a_update_assistant(self, retrieval_configs): |
| 103 | + async def test_a_update_assistant(self): |
111 | 104 |
|
112 | 105 | # Update an assistant. |
113 | 106 |
|
114 | | - name = "openai" |
115 | | - description = "test for openai" |
| 107 | + update_data_list = [ |
| 108 | + { |
| 109 | + "name": "openai", |
| 110 | + "description": "test for openai", |
| 111 | + "memory": AssistantZeroMemory(), |
| 112 | + "retrievals": [ |
| 113 | + RetrievalRef( |
| 114 | + type=RetrievalType.COLLECTION, |
| 115 | + id=self.collection_id, |
| 116 | + ), |
| 117 | + ], |
| 118 | + "retrieval_configs": RetrievalConfig( |
| 119 | + method="memory", |
| 120 | + top_k=2, |
| 121 | + max_tokens=4000, |
116 | 122 |
|
117 | | - res = await a_update_assistant(assistant_id=self.assistant_id, name=name, description=description, retrieval_configs=retrieval_configs) |
118 | | - res_dict = vars(res) |
119 | | - pytest.assume(res_dict["name"] == name) |
120 | | - pytest.assume(res_dict["description"] == description) |
121 | | - if isinstance(retrieval_configs, dict): |
122 | | - pytest.assume(vars(res_dict["retrieval_configs"]) == retrieval_configs) |
123 | | - else: |
124 | | - pytest.assume(res_dict["retrieval_configs"] == retrieval_configs) |
| 123 | + ), |
| 124 | + "tools": [ |
| 125 | + ToolRef( |
| 126 | + type=ToolType.ACTION, |
| 127 | + id=self.action_id, |
| 128 | + ), |
| 129 | + ToolRef( |
| 130 | + type=ToolType.PLUGIN, |
| 131 | + id="open_weather/get_hourly_forecast", |
| 132 | + ) |
| 133 | + ] |
| 134 | + }, |
| 135 | + { |
| 136 | + "name": "openai", |
| 137 | + "description": "test for openai", |
| 138 | + "memory": {"type": "naive"}, |
| 139 | + "retrievals": [{"type": "collection", "id": self.collection_id}], |
| 140 | + "retrieval_configs": {"method": "memory", "top_k": 2, "max_tokens": 4000}, |
| 141 | + "tools": [{"type": "action", "id": self.action_id}, |
| 142 | + {"type": "plugin", "id": "open_weather/get_hourly_forecast"}] |
| 143 | + |
| 144 | + } |
| 145 | + ] |
| 146 | + for update_data in update_data_list: |
| 147 | + res = await a_update_assistant(assistant_id=self.assistant_id, **update_data) |
| 148 | + res_dict = vars(res) |
| 149 | + logger.info(f'response_dict:{res_dict}, except_dict:{update_data}') |
| 150 | + assume_assistant_result(update_data, res_dict) |
125 | 151 |
|
126 | 152 | @pytest.mark.run(order=66) |
127 | 153 | @pytest.mark.asyncio |
|
0 commit comments