diff --git a/backend/app/api/docs/llm/get_llm_call.md b/backend/app/api/docs/llm/get_llm_call.md new file mode 100644 index 000000000..ab6217d20 --- /dev/null +++ b/backend/app/api/docs/llm/get_llm_call.md @@ -0,0 +1,10 @@ +Retrieve the status and results of an LLM call job by job ID. + +This endpoint allows you to poll for the status and results of an asynchronous LLM call job that was previously initiated via the POST `/llm/call` endpoint. + + +### Notes + +- This endpoint returns both the job status AND the actual LLM response when complete +- LLM responses are also delivered asynchronously via the callback URL (if provided) +- Jobs can be queried at any time after creation diff --git a/backend/app/api/routes/llm.py b/backend/app/api/routes/llm.py index ec48803ce..a4cd705c6 100644 --- a/backend/app/api/routes/llm.py +++ b/backend/app/api/routes/llm.py @@ -1,10 +1,20 @@ import logging +from uuid import UUID -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException from app.api.deps import AuthContextDep, SessionDep from app.api.permissions import Permission, require_permission -from app.models import LLMCallRequest, LLMCallResponse, Message +from app.crud.jobs import JobCrud +from app.crud.llm import get_llm_calls_by_job_id +from app.models import ( + LLMCallRequest, + LLMCallResponse, + LLMJobImmediatePublic, + LLMJobPublic, + JobStatus, +) +from app.models.llm.response import LLMResponse, Usage from app.services.llm.jobs import start_job from app.utils import APIResponse, validate_callback_url, load_description @@ -34,7 +44,7 @@ def llm_callback_notification(body: APIResponse[LLMCallResponse]): @router.post( "/llm/call", description=load_description("llm/llm_call.md"), - response_model=APIResponse[Message], + response_model=APIResponse[LLMJobImmediatePublic], callbacks=llm_callback_router.routes, dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], ) @@ -43,6 +53,7 @@ def llm_call( ): """ Endpoint to initiate an LLM call as a background job. + Returns job information for polling. """ project_id = _current_user.project_.id organization_id = _current_user.organization_.id @@ -50,15 +61,93 @@ def llm_call( if request.callback_url: validate_callback_url(str(request.callback_url)) - start_job( + job_id = start_job( db=session, request=request, project_id=project_id, organization_id=organization_id, ) - return APIResponse.success_response( - data=Message( - message=f"Your response is being generated and will be delivered via callback." - ), + # Fetch job details to return immediate response + job_crud = JobCrud(session=session) + job = job_crud.get(job_id=job_id) + + if not job: + raise HTTPException(status_code=404, detail="Job not found") + + if request.callback_url: + message = "Your response is being generated and will be delivered via callback." + else: + message = "Your response is being generated" + + job_response = LLMJobImmediatePublic( + job_id=job.id, + status=job.status.value, + message=message, + job_inserted_at=job.created_at, + job_updated_at=job.updated_at, ) + + return APIResponse.success_response(data=job_response) + + +@router.get( + "/llm/call/{job_id}", + description=load_description("llm/get_llm_call.md"), + response_model=APIResponse[LLMJobPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def get_llm_call_status( + _current_user: AuthContextDep, + session: SessionDep, + job_id: UUID, +) -> APIResponse[LLMJobPublic]: + """ + Poll for LLM call job status and results. + Returns job information with nested LLM response when complete. + """ + job_crud = JobCrud(session=session) + job = job_crud.get(job_id=job_id) + + if not job: + raise HTTPException(status_code=404, detail="Job not found") + + llm_call_response = None + if job.status.value == JobStatus.SUCCESS: + llm_calls = get_llm_calls_by_job_id( + session=session, job_id=job_id, project_id=_current_user.project_.id + ) + + if llm_calls: + # Get the first LLM call from the list which will be the only call for the job id + # since we initially won't be using this endpoint for llm chains + llm_call = llm_calls[0] + + llm_response = LLMResponse( + provider_response_id=llm_call.provider_response_id or "", + conversation_id=llm_call.conversation_id, + provider=llm_call.provider, + model=llm_call.model, + output=llm_call.content, + ) + + if not llm_call.usage: + raise HTTPException( + status_code=500, + detail="Completed LLM job is missing usage data", + ) + + llm_call_response = LLMCallResponse( + response=llm_response, + usage=Usage(**llm_call.usage), + provider_raw_response=None, + ) + + job_response = LLMJobPublic( + job_id=job.id, + status=job.status.value, + llm_response=llm_call_response, + error_message=job.error_message, + ) + + return APIResponse.success_response(data=job_response) diff --git a/backend/app/crud/llm.py b/backend/app/crud/llm.py index e0ca2b171..c7f5b1aee 100644 --- a/backend/app/crud/llm.py +++ b/backend/app/crud/llm.py @@ -1,11 +1,12 @@ import logging +import base64 +import json +from uuid import UUID from typing import Any, Literal -from uuid import UUID from sqlmodel import Session, select + from app.core.util import now -import base64 -import json from app.models.llm import LlmCall, LLMCallRequest, ConfigBlob from app.models.llm.request import ( TextInput, @@ -234,13 +235,13 @@ def get_llm_call_by_id( def get_llm_calls_by_job_id( - session: Session, - job_id: UUID, + session: Session, job_id: UUID, project_id: int ) -> list[LlmCall]: statement = ( select(LlmCall) .where( LlmCall.job_id == job_id, + LlmCall.project_id == project_id, LlmCall.deleted_at.is_(None), ) .order_by(LlmCall.created_at.desc()) diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index b5cb3f0c6..64856cdf9 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -115,6 +115,9 @@ LLMChainRequest, LLMChainResponse, LlmChain, + LLMJobBasePublic, + LLMJobImmediatePublic, + LLMJobPublic, ) from .message import Message diff --git a/backend/app/models/llm/__init__.py b/backend/app/models/llm/__init__.py index 1cb659f85..5d7d9326f 100644 --- a/backend/app/models/llm/__init__.py +++ b/backend/app/models/llm/__init__.py @@ -30,4 +30,7 @@ AudioOutput, LLMChainResponse, IntermediateChainResponse, + LLMJobBasePublic, + LLMJobImmediatePublic, + LLMJobPublic, ) diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index 8cc5f5c3e..65726de28 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -7,6 +7,7 @@ from pydantic import HttpUrl, model_validator from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import Field, Index, SQLModel, text + from app.core.util import now from app.models.llm.constants import ( DEFAULT_STT_MODEL, diff --git a/backend/app/models/llm/response.py b/backend/app/models/llm/response.py index fff62ea76..dfbcbf4f6 100644 --- a/backend/app/models/llm/response.py +++ b/backend/app/models/llm/response.py @@ -3,8 +3,12 @@ This module contains structured response models for LLM API calls. """ -from sqlmodel import SQLModel, Field +from datetime import datetime +from uuid import UUID from typing import Literal, Annotated + +from sqlmodel import SQLModel, Field + from app.models.llm.request import AudioContent, TextContent @@ -100,3 +104,26 @@ class IntermediateChainResponse(SQLModel): default=None, description="Unmodified raw response from the LLM provider from the current block", ) + + +# Job response models +class LLMJobBasePublic(SQLModel): + """Base response model for LLM job information.""" + + job_id: UUID + status: str # JobStatus from job.py + + +class LLMJobImmediatePublic(LLMJobBasePublic): + """Immediate response after creating an LLM job.""" + + message: str + job_inserted_at: datetime + job_updated_at: datetime + + +class LLMJobPublic(LLMJobBasePublic): + """Full job response with nested LLM response when complete.""" + + llm_response: LLMCallResponse | None = None + error_message: str | None = None diff --git a/backend/app/tests/api/routes/test_llm.py b/backend/app/tests/api/routes/test_llm.py index 5031effd9..885076bb0 100644 --- a/backend/app/tests/api/routes/test_llm.py +++ b/backend/app/tests/api/routes/test_llm.py @@ -1,15 +1,40 @@ +import pytest +from uuid import uuid4 from unittest.mock import patch +from sqlmodel import Session from fastapi.testclient import TestClient -from app.models import LLMCallRequest +from app.crud import JobCrud +from app.models import Job, JobStatus, JobUpdate +from app.models.llm.response import LLMCallResponse from app.models.llm.request import ( - QueryParams, LLMCallConfig, ConfigBlob, - KaapiCompletionConfig, NativeCompletionConfig, + KaapiCompletionConfig, + QueryParams, ) +from app.models.llm import LLMCallRequest +from app.tests.utils.auth import TestAuthContext +from app.tests.utils.llm import create_llm_job, create_llm_call_with_response + + +@pytest.fixture +def llm_job(db: Session) -> Job: + return create_llm_job(db) + + +@pytest.fixture +def llm_response_in_db( + db: Session, llm_job: Job, user_api_key: TestAuthContext +) -> LLMCallResponse: + return create_llm_call_with_response( + db, + job_id=llm_job.id, + project_id=user_api_key.project_id, + organization_id=user_api_key.organization_id, + ) def test_llm_call_success( @@ -247,3 +272,88 @@ def test_llm_call_guardrails_bypassed_still_succeeds( assert "response is being generated" in body["data"]["message"] mock_start_job.assert_called_once() + + +def test_get_llm_call_pending( + client: TestClient, + user_api_key_header: dict[str, str], + llm_job, +) -> None: + """Job in PENDING state returns status with no llm_response.""" + response = client.get( + f"/api/v1/llm/call/{llm_job.id}", + headers=user_api_key_header, + ) + + assert response.status_code == 200 + body = response.json() + assert body["success"] is True + assert body["data"]["job_id"] == str(llm_job.id) + assert body["data"]["status"] == "PENDING" + assert body["data"]["llm_response"] is None + + +def test_get_llm_call_success( + client: TestClient, + db: Session, + user_api_key_header: dict[str, str], + llm_response_in_db: LLMCallResponse, +) -> None: + """Job in SUCCESS state returns full llm_response with usage.""" + + JobCrud(db).update(llm_response_in_db.job_id, JobUpdate(status=JobStatus.SUCCESS)) + + response = client.get( + f"/api/v1/llm/call/{llm_response_in_db.job_id}", + headers=user_api_key_header, + ) + + assert response.status_code == 200 + body = response.json() + assert body["success"] is True + data = body["data"] + assert data["status"] == "SUCCESS" + assert data["llm_response"] is not None + assert data["llm_response"]["response"]["provider_response_id"] == "resp_abc123" + assert data["llm_response"]["response"]["provider"] == "openai" + assert data["llm_response"]["usage"]["input_tokens"] == 10 + assert data["llm_response"]["usage"]["output_tokens"] == 5 + assert data["llm_response"]["usage"]["total_tokens"] == 15 + + +def test_get_llm_call_failed( + client: TestClient, + db: Session, + user_api_key_header: dict[str, str], + llm_job, +) -> None: + JobCrud(db).update( + llm_job.id, + JobUpdate(status=JobStatus.FAILED, error_message="Provider timeout"), + ) + + response = client.get( + f"/api/v1/llm/call/{llm_job.id}", + headers=user_api_key_header, + ) + + assert response.status_code == 200 + body = response.json() + assert body["success"] is True + assert body["data"]["status"] == "FAILED" + assert body["data"]["error_message"] == "Provider timeout" + assert body["data"]["llm_response"] is None + + +def test_get_llm_call_not_found( + client: TestClient, + user_api_key_header: dict[str, str], +) -> None: + """Non-existent job_id returns 404.""" + + response = client.get( + f"/api/v1/llm/call/{uuid4()}", + headers=user_api_key_header, + ) + + assert response.status_code == 404 diff --git a/backend/app/tests/crud/test_llm.py b/backend/app/tests/crud/test_llm.py index ce6bb2e60..4a188392f 100644 --- a/backend/app/tests/crud/test_llm.py +++ b/backend/app/tests/crud/test_llm.py @@ -1,49 +1,42 @@ +import base64 from uuid import uuid4 import pytest -from sqlmodel import Session, select +from sqlmodel import Session -from app.crud import JobCrud from app.crud.llm import ( create_llm_call, get_llm_call_by_id, get_llm_calls_by_job_id, update_llm_call_response, ) -from app.models import JobType, Project, Organization +from app.models import Project, Organization from app.models.llm import ( ConfigBlob, LLMCallRequest, - LlmCall, QueryParams, ) from app.models.llm.request import ( KaapiCompletionConfig, LLMCallConfig, ) +from app.tests.utils.utils import get_project, get_organization +from app.tests.utils.llm import create_llm_job @pytest.fixture def test_project(db: Session) -> Project: - """Get the first available test project.""" - project = db.exec(select(Project).limit(1)).first() - assert project is not None, "No test project found in seed data" - return project + return get_project(db) @pytest.fixture -def test_organization(db: Session, test_project: Project) -> Organization: - """Get the organization for the test project.""" - org = db.get(Organization, test_project.organization_id) - assert org is not None, "No organization found for test project" - return org +def test_organization(db: Session) -> Organization: + return get_organization(db) @pytest.fixture def test_job(db: Session): - """Create a test job for LLM call tests.""" - crud = JobCrud(db) - return crud.create(job_type=JobType.LLM_API, trace_id="test-llm-trace") + return create_llm_job(db) @pytest.fixture @@ -308,14 +301,15 @@ def test_get_llm_calls_by_job_id( original_provider="openai", ) - llm_calls = get_llm_calls_by_job_id(db, test_job.id) + llm_calls = get_llm_calls_by_job_id(db, test_job.id, test_project.id) assert len(llm_calls) == 3 -def test_get_llm_calls_by_job_id_empty(db: Session) -> None: +def test_get_llm_calls_by_job_id_empty(db: Session, test_project: Project) -> None: """Test fetching LLM calls for a job with no calls.""" fake_job_id = uuid4() - llm_calls = get_llm_calls_by_job_id(db, fake_job_id) + + llm_calls = get_llm_calls_by_job_id(db, fake_job_id, test_project.id) assert llm_calls == [] @@ -421,7 +415,6 @@ def test_update_llm_call_response_with_audio_content( tts_config_blob: ConfigBlob, ) -> None: """Test updating LLM call with audio content calculates size.""" - import base64 request = LLMCallRequest( query=QueryParams(input="Test input"), diff --git a/backend/app/tests/utils/llm.py b/backend/app/tests/utils/llm.py new file mode 100644 index 000000000..8c3368026 --- /dev/null +++ b/backend/app/tests/utils/llm.py @@ -0,0 +1,71 @@ +from sqlmodel import Session + +from app.crud import JobCrud +from app.crud.llm import create_llm_call, update_llm_call_response +from app.models import JobType, Job +from app.models.llm.response import LLMCallResponse +from app.models.llm.request import ( + ConfigBlob, + KaapiCompletionConfig, + LLMCallConfig, + QueryParams, +) +from app.models.llm import LLMCallRequest + + +def create_llm_job(db: Session) -> Job: + """Create a persisted LLM_API job for use in tests.""" + return JobCrud(db).create(job_type=JobType.LLM_API, trace_id="test-llm-trace") + + +def create_llm_call_with_response( + db: Session, + job_id, + project_id: int, + organization_id: int, +) -> LLMCallResponse: + """ + Create a persisted LlmCall with a completed response for use in tests. + + Uses a standard OpenAI text-completion config and fixed response values + so tests can assert against predictable data. + """ + config_blob = ConfigBlob( + completion=KaapiCompletionConfig( + provider="openai", + params={ + "model": "gpt-4o", + "instructions": "You are helpful.", + "temperature": 0.7, + }, + type="text", + ) + ) + + llm_call = create_llm_call( + db, + request=LLMCallRequest( + query=QueryParams(input="What is the capital of France?"), + config=LLMCallConfig(blob=config_blob), + ), + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + resolved_config=config_blob, + original_provider="openai", + ) + + update_llm_call_response( + db, + llm_call_id=llm_call.id, + provider_response_id="resp_abc123", + content={"type": "text", "content": {"format": "text", "value": "Paris"}}, + usage={ + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + "reasoning_tokens": None, + }, + ) + + return llm_call