
Commit 945f676

Open AI: Running Threads synchronously (#143)
* first stab at adding a sync API
* refactoring a few things
* added test cases
* fixing a few test cases
* added a few more test cases
1 parent 7f8353d commit 945f676

2 files changed: +409 -83 lines changed
backend/app/api/routes/threads.py

Lines changed: 139 additions & 82 deletions
@@ -29,13 +29,89 @@ def send_callback(callback_url: str, data: dict):
         return False


+def handle_openai_error(e: openai.OpenAIError) -> str:
+    """Extract error message from OpenAI error."""
+    if isinstance(e.body, dict) and "message" in e.body:
+        return e.body["message"]
+    return str(e)
+
+
+def validate_thread(client: OpenAI, thread_id: str) -> tuple[bool, str]:
+    """Validate if a thread exists and has no active runs."""
+    if not thread_id:
+        return True, None
+
+    try:
+        runs = client.beta.threads.runs.list(thread_id=thread_id)
+        if runs.data and len(runs.data) > 0:
+            latest_run = runs.data[0]
+            if latest_run.status in ["queued", "in_progress", "requires_action"]:
+                return (
+                    False,
+                    f"There is an active run on this thread (status: {latest_run.status}). Please wait for it to complete.",
+                )
+        return True, None
+    except openai.OpenAIError:
+        return False, f"Invalid thread ID provided {thread_id}"
+
+
+def setup_thread(client: OpenAI, request: dict) -> tuple[bool, str]:
+    """Set up thread and add message, either creating new or using existing."""
+    thread_id = request.get("thread_id")
+
+    if thread_id:
+        try:
+            client.beta.threads.messages.create(
+                thread_id=thread_id, role="user", content=request["question"]
+            )
+            return True, None
+        except openai.OpenAIError as e:
+            return False, handle_openai_error(e)
+    else:
+        try:
+            thread = client.beta.threads.create()
+            client.beta.threads.messages.create(
+                thread_id=thread.id, role="user", content=request["question"]
+            )
+            request["thread_id"] = thread.id
+            return True, None
+        except openai.OpenAIError as e:
+            return False, handle_openai_error(e)
+
+
+def process_message_content(message_content: str, remove_citation: bool) -> str:
+    """Process message content, optionally removing citations."""
+    if remove_citation:
+        return re.sub(r"【\d+(?::\d+)?†[^】]*】", "", message_content)
+    return message_content
+
+
+def get_additional_data(request: dict) -> dict:
+    """Extract additional data from request, excluding specific keys."""
+    return {
+        k: v
+        for k, v in request.items()
+        if k not in {"question", "assistant_id", "callback_url", "thread_id"}
+    }
+
+
+def create_success_response(request: dict, message: str) -> APIResponse:
+    """Create a success response with the given message and request data."""
+    additional_data = get_additional_data(request)
+    return APIResponse.success_response(
+        data={
+            "status": "success",
+            "message": message,
+            "thread_id": request["thread_id"],
+            "endpoint": getattr(request, "endpoint", "some-default-endpoint"),
+            **additional_data,
+        }
+    )
+
+
 def process_run(request: dict, client: OpenAI):
-    """
-    Background task to run create_and_poll, then send the callback with the result.
-    This function is run in the background after we have already returned an initial response.
-    """
+    """Process a run and send callback with results."""
     try:
-        # Start the run
         run = client.beta.threads.runs.create_and_poll(
             thread_id=request["thread_id"],
             assistant_id=request["assistant_id"],
@@ -45,46 +121,19 @@ def process_run(request: dict, client: OpenAI):
             messages = client.beta.threads.messages.list(thread_id=request["thread_id"])
             latest_message = messages.data[0]
             message_content = latest_message.content[0].text.value
-
-            remove_citation = request.get("remove_citation", False)
-
-            if remove_citation:
-                message = re.sub(r"【\d+(?::\d+)?†[^】]*】", "", message_content)
-            else:
-                message = message_content
-
-            # Update the data dictionary with additional fields from the request, excluding specific keys
-            additional_data = {
-                k: v
-                for k, v in request.items()
-                if k not in {"question", "assistant_id", "callback_url", "thread_id"}
-            }
-            callback_response = APIResponse.success_response(
-                data={
-                    "status": "success",
-                    "message": message,
-                    "thread_id": request["thread_id"],
-                    "endpoint": getattr(request, "endpoint", "some-default-endpoint"),
-                    **additional_data,
-                }
+            message = process_message_content(
+                message_content, request.get("remove_citation", False)
             )
+            callback_response = create_success_response(request, message)
         else:
             callback_response = APIResponse.failure_response(
                 error=f"Run failed with status: {run.status}"
             )

-        # Send callback with results
         send_callback(request["callback_url"], callback_response.model_dump())

     except openai.OpenAIError as e:
-        # Handle any other OpenAI API errors
-        if isinstance(e.body, dict) and "message" in e.body:
-            error_message = e.body["message"]
-        else:
-            error_message = str(e)
-
-        callback_response = APIResponse.failure_response(error=error_message)
-
+        callback_response = APIResponse.failure_response(error=handle_openai_error(e))
         send_callback(request["callback_url"], callback_response.model_dump())

@@ -95,54 +144,20 @@ async def threads(
     _session: Session = Depends(get_db),
     _current_user: UserOrganization = Depends(get_current_user_org),
 ):
-    """
-    Accepts a question, assistant_id, callback_url, and optional thread_id from the request body.
-    Returns an immediate "processing" response, then continues to run create_and_poll in background.
-    Once completed, calls send_callback with the final result.
-    """
+    """Asynchronous endpoint that processes requests in background."""
     client = OpenAI(api_key=settings.OPENAI_API_KEY)

-    # Use get method to safely access thread_id
-    thread_id = request.get("thread_id")
+    # Validate thread
+    is_valid, error_message = validate_thread(client, request.get("thread_id"))
+    if not is_valid:
+        return APIResponse.failure_response(error=error_message)

-    # 1. Validate or check if there's an existing thread with an in-progress run
-    if thread_id:
-        try:
-            runs = client.beta.threads.runs.list(thread_id=thread_id)
-            # Get the most recent run (first in the list) if any
-            if runs.data and len(runs.data) > 0:
-                latest_run = runs.data[0]
-                if latest_run.status in ["queued", "in_progress", "requires_action"]:
-                    return APIResponse.failure_response(
-                        error=f"There is an active run on this thread (status: {latest_run.status}). Please wait for it to complete."
-                    )
-        except openai.OpenAIError:
-            # Handle invalid thread ID
-            return APIResponse.failure_response(
-                error=f"Invalid thread ID provided {thread_id}"
-            )
+    # Setup thread
+    is_success, error_message = setup_thread(client, request)
+    if not is_success:
+        return APIResponse.failure_response(error=error_message)

-        # Use existing thread
-        client.beta.threads.messages.create(
-            thread_id=thread_id, role="user", content=request["question"]
-        )
-    else:
-        try:
-            # Create new thread
-            thread = client.beta.threads.create()
-            client.beta.threads.messages.create(
-                thread_id=thread.id, role="user", content=request["question"]
-            )
-            request["thread_id"] = thread.id
-        except openai.OpenAIError as e:
-            # Handle any other OpenAI API errors
-            if isinstance(e.body, dict) and "message" in e.body:
-                error_message = e.body["message"]
-            else:
-                error_message = str(e)
-            return APIResponse.failure_response(error=error_message)
-
-    # 2. Send immediate response to complete the API call
+    # Send immediate response
     initial_response = APIResponse.success_response(
         data={
             "status": "processing",
@@ -152,8 +167,50 @@ async def threads(
         }
     )

-    # 3. Schedule the background task to run create_and_poll and send callback
+    # Schedule background task
    background_tasks.add_task(process_run, request, client)

-    # 4. Return immediately so the client knows we've accepted the request
     return initial_response
+
+
+@router.post("/threads/sync")
+async def threads_sync(
+    request: dict,
+    _session: Session = Depends(get_db),
+    _current_user: UserOrganization = Depends(get_current_user_org),
+):
+    """Synchronous endpoint that processes requests immediately."""
+    client = OpenAI(api_key=settings.OPENAI_API_KEY)
+
+    # Validate thread
+    is_valid, error_message = validate_thread(client, request.get("thread_id"))
+    if not is_valid:
+        return APIResponse.failure_response(error=error_message)
+
+    # Setup thread
+    is_success, error_message = setup_thread(client, request)
+    if not is_success:
+        return APIResponse.failure_response(error=error_message)
+
+    try:
+        # Process run
+        run = client.beta.threads.runs.create_and_poll(
+            thread_id=request["thread_id"],
+            assistant_id=request["assistant_id"],
+        )
+
+        if run.status == "completed":
+            messages = client.beta.threads.messages.list(thread_id=request["thread_id"])
+            latest_message = messages.data[0]
+            message_content = latest_message.content[0].text.value
+            message = process_message_content(
+                message_content, request.get("remove_citation", False)
+            )
+            return create_success_response(request, message)
+        else:
+            return APIResponse.failure_response(
+                error=f"Run failed with status: {run.status}"
+            )
+
+    except openai.OpenAIError as e:
+        return APIResponse.failure_response(error=handle_openai_error(e))
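
For reference, a minimal sketch of how a client might call the new synchronous route added in this commit. The payload keys (question, assistant_id, thread_id, remove_citation) come from what threads_sync reads out of the request dict in the diff above; the base URL, route prefix, auth header, and IDs are assumptions for illustration only.

# Hypothetical client call against the new POST /threads/sync route.
# BASE_URL, the Authorization scheme, and the IDs below are assumptions;
# the payload keys mirror what threads_sync reads from the request dict.
import requests

BASE_URL = "http://localhost:8000"  # assumed host and prefix; adjust to the real deployment

payload = {
    "question": "What were the key findings in the last report?",
    "assistant_id": "asst_123",   # placeholder assistant ID
    # "thread_id": "thread_abc",  # optional: continue an existing thread
    "remove_citation": True,      # strips 【...】 citation markers from the reply
}

resp = requests.post(
    f"{BASE_URL}/threads/sync",
    json=payload,
    headers={"Authorization": "Bearer <token>"},  # assumed auth scheme
    timeout=120,  # create_and_poll blocks until the run finishes
)
print(resp.json())  # APIResponse envelope: success data or an error message

Note that any extra keys placed in the payload are echoed back in the success response, since create_success_response merges in everything that get_additional_data does not exclude.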
