Skip to content

Commit 175c686

Browse files
authored
feat: add submit_input method for thread run (#250)
* feat: add submit_input method for thread run * fix: change gt to gte operator in poll_for_status * fix: add =None to optionals
1 parent 7ff423d commit 175c686

File tree

3 files changed

+77
-7
lines changed

3 files changed

+77
-7
lines changed

ai21/clients/common/beta/assistant/runs.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from abc import ABC, abstractmethod
4-
from typing import List
4+
from typing import Any, List
55

66
from ai21.models.assistant.assistant import Optimization
77
from ai21.models.assistant.run import ToolOutput
@@ -85,3 +85,19 @@ def create_and_poll(
8585
**kwargs,
8686
) -> RunResponse:
8787
pass
88+
89+
@abstractmethod
90+
def submit_input(self, *, thread_id: str, run_id: str, input: Any) -> RunResponse:
91+
pass
92+
93+
@abstractmethod
94+
def submit_input_and_poll(
95+
self,
96+
*,
97+
thread_id: str,
98+
run_id: str,
99+
input: Any,
100+
poll_interval_sec: float,
101+
poll_timeout_sec: float,
102+
) -> RunResponse:
103+
pass

ai21/clients/studio/resources/beta/assistant/thread_runs.py

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import asyncio
44
import time
5-
from typing import List
5+
from typing import Any, List
66

77
from ai21.clients.common.beta.assistant.runs import BaseRuns
88
from ai21.clients.studio.resources.studio_resource import StudioResource, AsyncStudioResource
@@ -73,7 +73,7 @@ def _poll_for_status(
7373
if run.status in TERMINATED_RUN_STATUSES:
7474
return run
7575

76-
if (time.time() - start_time) > poll_timeout:
76+
if (time.time() - start_time) >= poll_timeout:
7777
return run
7878

7979
time.sleep(poll_interval)
@@ -97,6 +97,30 @@ def create_and_poll(
9797
thread_id=thread_id, run_id=run.id, poll_interval=poll_interval_sec, poll_timeout=poll_timeout_sec
9898
)
9999

100+
def submit_input(self, *, thread_id: str, run_id: str, input: Any) -> RunResponse:
101+
body = dict(input=input)
102+
103+
return self._post(
104+
path=f"/threads/{thread_id}/{self._module_name}/{run_id}/submit_input",
105+
body=body,
106+
response_cls=RunResponse,
107+
)
108+
109+
def submit_input_and_poll(
110+
self,
111+
*,
112+
thread_id: str,
113+
run_id: str,
114+
input: Any,
115+
poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL,
116+
poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT,
117+
) -> RunResponse:
118+
run = self.submit_input(thread_id=thread_id, run_id=run_id, input=input)
119+
120+
return self._poll_for_status(
121+
thread_id=thread_id, run_id=run.id, poll_interval=poll_interval_sec, poll_timeout=poll_timeout_sec
122+
)
123+
100124

101125
class AsyncThreadRuns(AsyncStudioResource, BaseRuns):
102126
async def create(
@@ -156,7 +180,7 @@ async def _poll_for_status(
156180
if run.status in TERMINATED_RUN_STATUSES:
157181
return run
158182

159-
if (time.time() - start_time) > poll_timeout:
183+
if (time.time() - start_time) >= poll_timeout:
160184
return run
161185

162186
await asyncio.sleep(poll_interval)
@@ -179,3 +203,27 @@ async def create_and_poll(
179203
return await self._poll_for_status(
180204
thread_id=thread_id, run_id=run.id, poll_interval=poll_interval_sec, poll_timeout=poll_timeout_sec
181205
)
206+
207+
async def submit_input(self, *, thread_id: str, run_id: str, input: Any) -> RunResponse:
208+
body = dict(input=input)
209+
210+
return await self._post(
211+
path=f"/threads/{thread_id}/{self._module_name}/{run_id}/submit_inputs",
212+
body=body,
213+
response_cls=RunResponse,
214+
)
215+
216+
async def submit_input_and_poll(
217+
self,
218+
*,
219+
thread_id: str,
220+
run_id: str,
221+
input: Any,
222+
poll_interval_sec: float = DEFAULT_RUN_POLL_INTERVAL,
223+
poll_timeout_sec: float = DEFAULT_RUN_POLL_TIMEOUT,
224+
) -> RunResponse:
225+
run = await self.submit_input(thread_id=thread_id, run_id=run_id, input=input)
226+
227+
return await self._poll_for_status(
228+
thread_id=thread_id, run_id=run.id, poll_interval=poll_interval_sec, poll_timeout=poll_timeout_sec
229+
)

ai21/models/assistant/run.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Literal, Any, List, Set
1+
from typing import Dict, Literal, Any, List, Set, Optional
22

33
from typing_extensions import TypedDict
44

@@ -40,6 +40,12 @@ class SubmitToolCallOutputs(TypedDict):
4040
tool_calls: List[ToolOutput]
4141

4242

43+
class SubmitInput(TypedDict):
44+
event_name: str
45+
data: Dict[str, Any]
46+
47+
4348
class RequiredAction(TypedDict):
44-
type: Literal["submit_tool_outputs"]
45-
submit_tool_outputs: SubmitToolCallOutputs
49+
type: Literal["submit_tool_outputs", "submit_input"]
50+
submit_tool_outputs: Optional[SubmitToolCallOutputs] = None
51+
submit_input: Optional[SubmitInput] = None

0 commit comments

Comments
 (0)