Skip to content

Commit 54d9e41

Browse files
committed
feat(lib): add interactive mode for automate
1 parent 9710840 commit 54d9e41

3 files changed

Lines changed: 298 additions & 0 deletions

File tree

src/tabstack/lib/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from ._models import (
2+
InputFormField as InputFormField,
3+
InputFormEvent as InputFormEvent,
4+
InputFormResponse as InputFormResponse,
5+
InputDeclinedResponse as InputDeclinedResponse,
6+
InputExpiredError as InputExpiredError,
7+
)
8+
from ._interactive import (
9+
automate_interactive as automate_interactive,
10+
async_automate_interactive as async_automate_interactive,
11+
)

src/tabstack/lib/_interactive.py

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, Dict, Iterator, AsyncIterator, Callable, Awaitable, Union
4+
5+
import httpx
6+
7+
from .._types import Body, Omit, Query, Headers, NotGiven, omit, not_given
8+
from .._compat import model_parse
9+
from .._exceptions import APIStatusError
10+
from ..types.automate_event import AutomateEvent
11+
from ..types.agent_automate_params import GeoTarget
12+
from ._models import (
13+
InputFormEvent,
14+
InputFormResponse,
15+
InputDeclinedResponse,
16+
InputExpiredError,
17+
)
18+
19+
__all__ = [
20+
"automate_interactive",
21+
"async_automate_interactive",
22+
]
23+
24+
25+
def _build_input_payload(response: Union[InputFormResponse, InputDeclinedResponse]) -> Dict[str, Any]:
26+
if isinstance(response, InputFormResponse):
27+
return {"type": "form", "fields": response.fields}
28+
return {"type": "declined", "reason": response.reason}
29+
30+
31+
def _submit_input(
32+
client: Any,
33+
question_id: str,
34+
payload: Dict[str, Any],
35+
) -> None:
36+
try:
37+
client.post(
38+
f"/automate/{question_id}/input",
39+
cast_to=httpx.Response,
40+
body=payload,
41+
)
42+
except APIStatusError as e:
43+
if e.status_code == 410:
44+
raise InputExpiredError(question_id) from e
45+
raise
46+
47+
48+
async def _async_submit_input(
49+
client: Any,
50+
question_id: str,
51+
payload: Dict[str, Any],
52+
) -> None:
53+
try:
54+
await client.post(
55+
f"/automate/{question_id}/input",
56+
cast_to=httpx.Response,
57+
body=payload,
58+
)
59+
except APIStatusError as e:
60+
if e.status_code == 410:
61+
raise InputExpiredError(question_id) from e
62+
raise
63+
64+
65+
def automate_interactive(
66+
client: Any,
67+
*,
68+
task: str,
69+
on_input: Callable[[InputFormEvent], Union[InputFormResponse, InputDeclinedResponse]],
70+
data: object | Omit = omit,
71+
geo_target: GeoTarget | Omit = omit,
72+
guardrails: str | Omit = omit,
73+
max_iterations: int | Omit = omit,
74+
max_validation_attempts: int | Omit = omit,
75+
url: str | Omit = omit,
76+
extra_headers: Headers | None = None,
77+
extra_query: Query | None = None,
78+
extra_body: Body | None = None,
79+
timeout: float | httpx.Timeout | None | NotGiven = not_given,
80+
) -> Iterator[AutomateEvent]:
81+
"""Run an interactive automate session with a callback for input requests.
82+
83+
Wraps ``client.agent.automate()`` with ``interactive: true`` and intercepts
84+
``input:form`` events. When one arrives, ``on_input`` is called with the
85+
parsed form data. The callback's return value is submitted to the API
86+
automatically.
87+
88+
All events (including ``input:form``) are yielded to the caller.
89+
90+
Args:
91+
client: A ``Tabstack`` client instance.
92+
task: The task description in natural language.
93+
on_input: Callback invoked for each ``input:form`` event. Return
94+
``InputFormResponse`` to fill the form or ``InputDeclinedResponse``
95+
to decline.
96+
data: JSON data to provide context for form filling or complex tasks.
97+
geo_target: Optional geotargeting parameters.
98+
guardrails: Safety constraints for execution.
99+
max_iterations: Maximum task iterations.
100+
max_validation_attempts: Maximum validation attempts.
101+
url: Starting URL for the task.
102+
extra_headers: Send extra headers.
103+
extra_query: Add additional query parameters to the request.
104+
extra_body: Add additional JSON properties to the request.
105+
timeout: Override the client-level default timeout for this request, in seconds.
106+
"""
107+
merged_extra_body: dict[str, Any] = {**(extra_body or {}), "interactive": True}
108+
109+
stream = client.agent.automate(
110+
task=task,
111+
data=data,
112+
geo_target=geo_target,
113+
guardrails=guardrails,
114+
max_iterations=max_iterations,
115+
max_validation_attempts=max_validation_attempts,
116+
url=url,
117+
extra_headers=extra_headers,
118+
extra_query=extra_query,
119+
extra_body=merged_extra_body,
120+
timeout=timeout,
121+
)
122+
123+
try:
124+
for event in stream:
125+
if event.event == "input:form" and isinstance(event.data, dict):
126+
form_event = model_parse(InputFormEvent, event.data)
127+
callback_response = on_input(form_event)
128+
payload = _build_input_payload(callback_response)
129+
_submit_input(client, form_event.question_id, payload)
130+
131+
yield event
132+
finally:
133+
stream.response.close()
134+
135+
136+
async def async_automate_interactive(
137+
client: Any,
138+
*,
139+
task: str,
140+
on_input: Callable[[InputFormEvent], Awaitable[Union[InputFormResponse, InputDeclinedResponse]]],
141+
data: object | Omit = omit,
142+
geo_target: GeoTarget | Omit = omit,
143+
guardrails: str | Omit = omit,
144+
max_iterations: int | Omit = omit,
145+
max_validation_attempts: int | Omit = omit,
146+
url: str | Omit = omit,
147+
extra_headers: Headers | None = None,
148+
extra_query: Query | None = None,
149+
extra_body: Body | None = None,
150+
timeout: float | httpx.Timeout | None | NotGiven = not_given,
151+
) -> AsyncIterator[AutomateEvent]:
152+
"""Async version of :func:`automate_interactive`.
153+
154+
The ``on_input`` callback must be an async function.
155+
156+
Args:
157+
client: An ``AsyncTabstack`` client instance.
158+
task: The task description in natural language.
159+
on_input: Async callback invoked for each ``input:form`` event. Return
160+
``InputFormResponse`` to fill the form or ``InputDeclinedResponse``
161+
to decline.
162+
data: JSON data to provide context for form filling or complex tasks.
163+
geo_target: Optional geotargeting parameters.
164+
guardrails: Safety constraints for execution.
165+
max_iterations: Maximum task iterations.
166+
max_validation_attempts: Maximum validation attempts.
167+
url: Starting URL for the task.
168+
extra_headers: Send extra headers.
169+
extra_query: Add additional query parameters to the request.
170+
extra_body: Add additional JSON properties to the request.
171+
timeout: Override the client-level default timeout for this request, in seconds.
172+
"""
173+
merged_extra_body: dict[str, Any] = {**(extra_body or {}), "interactive": True}
174+
175+
stream = await client.agent.automate(
176+
task=task,
177+
data=data,
178+
geo_target=geo_target,
179+
guardrails=guardrails,
180+
max_iterations=max_iterations,
181+
max_validation_attempts=max_validation_attempts,
182+
url=url,
183+
extra_headers=extra_headers,
184+
extra_query=extra_query,
185+
extra_body=merged_extra_body,
186+
timeout=timeout,
187+
)
188+
189+
try:
190+
async for event in stream:
191+
if event.event == "input:form" and isinstance(event.data, dict):
192+
form_event = model_parse(InputFormEvent, event.data)
193+
callback_response = await on_input(form_event)
194+
payload = _build_input_payload(callback_response)
195+
await _async_submit_input(client, form_event.question_id, payload)
196+
197+
yield event
198+
finally:
199+
await stream.response.aclose()

src/tabstack/lib/_models.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
from __future__ import annotations
2+
3+
from typing import Any, Dict, List, Optional
4+
5+
import pydantic
6+
7+
from .._models import BaseModel
8+
from .._compat import PYDANTIC_V1
9+
from .._exceptions import TabstackError
10+
11+
__all__ = [
12+
"InputFormField",
13+
"InputFormEvent",
14+
"InputFormResponse",
15+
"InputDeclinedResponse",
16+
"InputExpiredError",
17+
]
18+
19+
20+
class InputFormField(BaseModel):
21+
"""A single field in an input form request."""
22+
23+
name: str
24+
"""Field identifier used as the key in the response."""
25+
26+
label: str
27+
"""Human-readable display label."""
28+
29+
sensitive: bool = False
30+
"""Whether this field contains sensitive data (e.g., passwords)."""
31+
32+
33+
class InputFormEvent(BaseModel):
34+
"""Parsed payload from an `input:form` SSE event.
35+
36+
The automation agent has encountered a form and is requesting
37+
the caller to provide values for one or more fields.
38+
"""
39+
40+
if PYDANTIC_V1:
41+
42+
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
43+
extra: Any = pydantic.Extra.allow # type: ignore
44+
allow_population_by_field_name = True
45+
else:
46+
model_config = pydantic.ConfigDict(extra="allow", populate_by_name=True)
47+
48+
question_id: str = pydantic.Field(alias="questionId")
49+
"""Unique identifier for this input request."""
50+
51+
question: str
52+
"""Human-readable prompt describing what input is needed."""
53+
54+
fields: List[InputFormField]
55+
"""The form fields to be filled in."""
56+
57+
page_url: Optional[str] = pydantic.Field(default=None, alias="pageUrl")
58+
"""URL of the page where the form was encountered."""
59+
60+
page_title: Optional[str] = pydantic.Field(default=None, alias="pageTitle")
61+
"""Title of the page where the form was encountered."""
62+
63+
timeout_ms: int = pydantic.Field(alias="timeoutMs")
64+
"""How long the server will wait for a response, in milliseconds."""
65+
66+
67+
class InputFormResponse(BaseModel):
68+
"""Returned by the callback to provide form field values."""
69+
70+
fields: Dict[str, Any]
71+
"""Field values keyed by field name."""
72+
73+
74+
class InputDeclinedResponse(BaseModel):
75+
"""Returned by the callback to decline an input request."""
76+
77+
reason: str = ""
78+
"""Optional reason for declining."""
79+
80+
81+
class InputExpiredError(TabstackError):
82+
"""Raised when the input form has expired or was already answered (410 Gone)."""
83+
84+
question_id: str
85+
86+
def __init__(self, question_id: str) -> None:
87+
super().__init__(f"Input request '{question_id}' has expired or was already answered")
88+
self.question_id = question_id

0 commit comments

Comments
 (0)