diff --git a/products/tasks/backend/api.py b/products/tasks/backend/api.py index 4a4c64df1d25..d20a2159b15a 100644 --- a/products/tasks/backend/api.py +++ b/products/tasks/backend/api.py @@ -14,6 +14,7 @@ from django.utils import timezone import requests as http_requests +import jsonschema import posthoganalytics from drf_spectacular.utils import OpenApiResponse, extend_schema from rest_framework import status, viewsets @@ -58,6 +59,7 @@ TaskRunRelayMessageRequestSerializer, TaskRunRelayMessageResponseSerializer, TaskRunSessionLogsQuerySerializer, + TaskRunSetOutputRequestSerializer, TaskRunUpdateSerializer, TaskSerializer, ) @@ -539,7 +541,7 @@ def perform_create(self, serializer): serializer.save(team=self.team, task=task) @validated_request( - request_serializer=None, + request_serializer=TaskRunSetOutputRequestSerializer, responses={ 200: OpenApiResponse(response=TaskRunDetailSerializer, description="Run with updated output"), 404: OpenApiResponse(description="Run not found"), @@ -555,17 +557,20 @@ def perform_create(self, serializer): ) def set_output(self, request, pk=None, **kwargs): task_run = cast(TaskRun, self.get_object()) + task = cast(Task, task_run.task) + output_data = request.validated_data["output"] - output_data = request.data.get("output", {}) - if not isinstance(output_data, dict): - return Response( - ErrorResponseSerializer({"error": "output must be a dictionary"}).data, - status=status.HTTP_400_BAD_REQUEST, - ) - - # TODO: Validate output data according to schema for the task type. + if task.json_schema: + try: + jsonschema.validate(instance=output_data, schema=task.json_schema) + except jsonschema.ValidationError as e: + return Response( + ErrorResponseSerializer({"error": f"Output validation error: {e.message}"}).data, + status=status.HTTP_400_BAD_REQUEST, + ) task_run.output = output_data task_run.save(update_fields=["output", "updated_at"]) + self._signal_workflow_completion(task_run, TaskRun.Status.COMPLETED, None) task_run.publish_stream_state_event() self._post_slack_update_for_pr(task_run) diff --git a/products/tasks/backend/models.py b/products/tasks/backend/models.py index 19c79d5c8597..6266bc102d60 100644 --- a/products/tasks/backend/models.py +++ b/products/tasks/backend/models.py @@ -6,6 +6,11 @@ import secrets from typing import TYPE_CHECKING, Any, Literal, Optional +from django.db.models.signals import post_save +from django.dispatch import receiver + +from pydantic import BaseModel + if TYPE_CHECKING: from products.slack_app.backend.slack_thread import SlackThreadContext @@ -35,6 +40,12 @@ LogLevel = Literal["debug", "info", "warn", "error"] +def resolve_schema(schema: type[BaseModel] | dict) -> dict: + if isinstance(schema, dict): + return schema + return schema.model_json_schema() + + class Task(DeletedMetaFields, models.Model): class OriginProduct(models.TextChoices): ERROR_TRACKING = "error_tracking", "Error Tracking" @@ -244,6 +255,7 @@ def create_and_run( signal_report_id: str | None = None, sandbox_environment_id: str | None = None, internal: bool = False, + output_schema: type[BaseModel] | dict | None = None, ) -> "Task": from products.tasks.backend.temporal.client import execute_task_processing_workflow @@ -270,6 +282,7 @@ def create_and_run( github_integration=github_integration, repository=repository, internal=internal, + json_schema=resolve_schema(output_schema) if output_schema else None, **({"signal_report_id": signal_report_id} if signal_report_id else {}), ) @@ -521,6 +534,20 @@ def mark_completed(self): {"duration_seconds": self._duration_seconds()}, ) + def track_structured_result(self): + """Track a structured result event with properties from the run output.""" + if not self.output: + return + + try: + self.capture_event("task_run_structured_result", {"result": self.output}) + except Exception as e: + logger.warning( + "task_run.track_structured_result_failed", + task_run_id=str(self.id), + error=str(e), + ) + def mark_failed(self, error: str): """Mark the progress as failed with an error message.""" self.status = self.Status.FAILED @@ -852,3 +879,21 @@ class Meta: def __str__(self): return f"{self.user} redeemed {self.invite_code}" + + +@receiver(post_save, sender=TaskRun) +def track_task_run_completion(sender, instance: TaskRun, created: bool, **kwargs): + try: + if ( + not created + and instance.status == TaskRun.Status.COMPLETED + and instance.output + and instance.task.json_schema + ): + instance.track_structured_result() + except Exception as e: + logger.warning( + "task_run.track_task_run_completion_failed", + task_run_id=str(instance.id), + error=str(e), + ) diff --git a/products/tasks/backend/serializers.py b/products/tasks/backend/serializers.py index c391786f2f83..f143d9fec3c7 100644 --- a/products/tasks/backend/serializers.py +++ b/products/tasks/backend/serializers.py @@ -218,6 +218,12 @@ def update(self, instance, validated_data): return super().update(instance, validated_data) +class TaskRunSetOutputRequestSerializer(serializers.Serializer): + output = serializers.JSONField( + help_text="Output data from the run. Validated against the task's json_schema if one is set." + ) + + class ErrorResponseSerializer(serializers.Serializer): error = serializers.CharField(help_text="Error message") diff --git a/products/tasks/backend/temporal/process_task/activities/get_task_processing_context.py b/products/tasks/backend/temporal/process_task/activities/get_task_processing_context.py index 23302bbbc1ce..b735dd257a79 100644 --- a/products/tasks/backend/temporal/process_task/activities/get_task_processing_context.py +++ b/products/tasks/backend/temporal/process_task/activities/get_task_processing_context.py @@ -38,6 +38,7 @@ class TaskProcessingContext: _branch: str | None = None sandbox_environment_name: str | None = None allowed_domains: list[str] | None = None + json_schema: dict | None = None @property def mode(self) -> str: @@ -164,4 +165,5 @@ def get_task_processing_context(input: GetTaskProcessingContextInput) -> TaskPro _branch=task_run.branch, sandbox_environment_name=sandbox_environment_name, allowed_domains=allowed_domains, + json_schema=task.json_schema, ) diff --git a/products/tasks/frontend/generated/api.schemas.ts b/products/tasks/frontend/generated/api.schemas.ts index 8a6d5df701e2..87d56f4e2651 100644 --- a/products/tasks/frontend/generated/api.schemas.ts +++ b/products/tasks/frontend/generated/api.schemas.ts @@ -604,6 +604,11 @@ export interface TaskRunRelayMessageResponseApi { relay_id?: string } +export interface PatchedTaskRunSetOutputRequestApi { + /** Output data from the run. Validated against the task's json_schema if one is set. */ + output?: unknown +} + /** * * `needs_setup` - needs_setup * `detected` - detected diff --git a/products/tasks/frontend/generated/api.ts b/products/tasks/frontend/generated/api.ts index 202b61e9fb31..11146aaae526 100644 --- a/products/tasks/frontend/generated/api.ts +++ b/products/tasks/frontend/generated/api.ts @@ -15,6 +15,7 @@ import type { PaginatedTaskListApi, PaginatedTaskRunDetailListApi, PatchedTaskApi, + PatchedTaskRunSetOutputRequestApi, PatchedTaskRunUpdateApi, RepositoryReadinessResponseApi, SandboxEnvironmentApi, @@ -556,11 +557,14 @@ export const tasksRunsSetOutputPartialUpdate = async ( projectId: string, taskId: string, id: string, + patchedTaskRunSetOutputRequestApi: PatchedTaskRunSetOutputRequestApi, options?: RequestInit ): Promise => { return apiMutator(getTasksRunsSetOutputPartialUpdateUrl(projectId, taskId, id), { ...options, method: 'PATCH', + headers: { 'Content-Type': 'application/json', ...options?.headers }, + body: JSON.stringify(patchedTaskRunSetOutputRequestApi), }) } diff --git a/services/mcp/src/api/generated.ts b/services/mcp/src/api/generated.ts index 69269ff4c88d..72b1416613c6 100644 --- a/services/mcp/src/api/generated.ts +++ b/services/mcp/src/api/generated.ts @@ -26007,6 +26007,11 @@ export namespace Schemas { readonly created_by?: UserBasic; } + export interface PatchedTaskRunSetOutputRequest { + /** Output data from the run. Validated against the task's json_schema if one is set. */ + output?: unknown; + } + /** * * `not_started` - not_started * `queued` - queued