|
4 | 4 |
|
5 | 5 | from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException |
6 | 6 | from openai import OpenAI |
| 7 | +from pydantic import BaseModel, Field |
7 | 8 | from sqlmodel import Session |
| 9 | +from typing import Optional |
8 | 10 | from langfuse.decorators import observe, langfuse_context |
9 | 11 |
|
10 | 12 | from app.api.deps import get_current_user_org, get_db |
|
19 | 21 | router = APIRouter(tags=["threads"]) |
20 | 22 |
|
21 | 23 |
|
| 24 | +class StartThreadRequest(BaseModel): |
| 25 | + question: str = Field(..., description="The user's input question.") |
| 26 | + assistant_id: str = Field(..., description="The ID of the assistant to be used.") |
| 27 | + remove_citation: bool = Field( |
| 28 | + default=False, description="Whether to remove citations from the response." |
| 29 | + ) |
| 30 | + thread_id: Optional[str] = Field( |
| 31 | + default=None, |
| 32 | + description="An optional existing thread ID to continue the conversation.", |
| 33 | + ) |
| 34 | + |
| 35 | + |
22 | 36 | def send_callback(callback_url: str, data: dict): |
23 | 37 | """Send results to the callback URL (synchronously).""" |
24 | 38 | try: |
@@ -340,14 +354,15 @@ async def threads_sync( |
340 | 354 |
|
341 | 355 | @router.post("/threads/start") |
342 | 356 | async def start_thread( |
343 | | - request: dict, |
| 357 | + request: StartThreadRequest, |
344 | 358 | background_tasks: BackgroundTasks, |
345 | 359 | db: Session = Depends(get_db), |
346 | 360 | _current_user: UserOrganization = Depends(get_current_user_org), |
347 | 361 | ): |
348 | 362 | """ |
349 | 363 | Create a new OpenAI thread for the given question and start polling in the background. |
350 | 364 | """ |
| 365 | + request = request.model_dump() |
351 | 366 | prompt = request["question"] |
352 | 367 | credentials = get_provider_credential( |
353 | 368 | session=db, |
|
0 commit comments