22
33import asyncio
44import time
5- from typing import List
5+ from typing import Any , List
66
77from ai21 .clients .common .beta .assistant .runs import BaseRuns
88from ai21 .clients .studio .resources .studio_resource import StudioResource , AsyncStudioResource
@@ -73,7 +73,7 @@ def _poll_for_status(
7373 if run .status in TERMINATED_RUN_STATUSES :
7474 return run
7575
76- if (time .time () - start_time ) > poll_timeout :
76+ if (time .time () - start_time ) >= poll_timeout :
7777 return run
7878
7979 time .sleep (poll_interval )
@@ -97,6 +97,30 @@ def create_and_poll(
9797 thread_id = thread_id , run_id = run .id , poll_interval = poll_interval_sec , poll_timeout = poll_timeout_sec
9898 )
9999
100+ def submit_input (self , * , thread_id : str , run_id : str , input : Any ) -> RunResponse :
101+ body = dict (input = input )
102+
103+ return self ._post (
104+ path = f"/threads/{ thread_id } /{ self ._module_name } /{ run_id } /submit_input" ,
105+ body = body ,
106+ response_cls = RunResponse ,
107+ )
108+
109+ def submit_input_and_poll (
110+ self ,
111+ * ,
112+ thread_id : str ,
113+ run_id : str ,
114+ input : Any ,
115+ poll_interval_sec : float = DEFAULT_RUN_POLL_INTERVAL ,
116+ poll_timeout_sec : float = DEFAULT_RUN_POLL_TIMEOUT ,
117+ ) -> RunResponse :
118+ run = self .submit_input (thread_id = thread_id , run_id = run_id , input = input )
119+
120+ return self ._poll_for_status (
121+ thread_id = thread_id , run_id = run .id , poll_interval = poll_interval_sec , poll_timeout = poll_timeout_sec
122+ )
123+
100124
101125class AsyncThreadRuns (AsyncStudioResource , BaseRuns ):
102126 async def create (
@@ -156,7 +180,7 @@ async def _poll_for_status(
156180 if run .status in TERMINATED_RUN_STATUSES :
157181 return run
158182
159- if (time .time () - start_time ) > poll_timeout :
183+ if (time .time () - start_time ) >= poll_timeout :
160184 return run
161185
162186 await asyncio .sleep (poll_interval )
@@ -179,3 +203,27 @@ async def create_and_poll(
179203 return await self ._poll_for_status (
180204 thread_id = thread_id , run_id = run .id , poll_interval = poll_interval_sec , poll_timeout = poll_timeout_sec
181205 )
206+
207+ async def submit_input (self , * , thread_id : str , run_id : str , input : Any ) -> RunResponse :
208+ body = dict (input = input )
209+
210+ return await self ._post (
211+ path = f"/threads/{ thread_id } /{ self ._module_name } /{ run_id } /submit_inputs" ,
212+ body = body ,
213+ response_cls = RunResponse ,
214+ )
215+
216+ async def submit_input_and_poll (
217+ self ,
218+ * ,
219+ thread_id : str ,
220+ run_id : str ,
221+ input : Any ,
222+ poll_interval_sec : float = DEFAULT_RUN_POLL_INTERVAL ,
223+ poll_timeout_sec : float = DEFAULT_RUN_POLL_TIMEOUT ,
224+ ) -> RunResponse :
225+ run = await self .submit_input (thread_id = thread_id , run_id = run_id , input = input )
226+
227+ return await self ._poll_for_status (
228+ thread_id = thread_id , run_id = run .id , poll_interval = poll_interval_sec , poll_timeout = poll_timeout_sec
229+ )
0 commit comments