Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 42 additions & 15 deletions products/tasks/backend/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from posthog.event_usage import groups
from posthog.permissions import APIScopePermission
from posthog.rate_limit import CodeInviteThrottle
from posthog.renderers import ServerSentEventRenderer
from posthog.storage import object_storage

from ee.hogai.utils.aio import async_to_sync
Expand Down Expand Up @@ -441,6 +442,7 @@ def partial_update(self, request, *args, **kwargs):
task_run.completed_at = timezone.now()

task_run.save()
task_run.publish_stream_state_event()

# Signal Temporal and post Slack updates after commit to avoid
# holding the row lock during external calls.
Expand Down Expand Up @@ -564,6 +566,7 @@ def set_output(self, request, pk=None, **kwargs):
# TODO: Validate output data according to schema for the task type.
task_run.output = output_data
task_run.save(update_fields=["output", "updated_at"])
task_run.publish_stream_state_event()
self._post_slack_update_for_pr(task_run)

return Response(TaskRunDetailSerializer(task_run, context=self.get_serializer_context()).data)
Expand Down Expand Up @@ -1022,6 +1025,8 @@ def session_logs(self, request, pk=None, **kwargs):
response = JsonResponse([], safe=False)
response["X-Total-Count"] = "0"
response["X-Filtered-Count"] = "0"
response["X-Matching-Count"] = "0"
response["X-Has-More"] = "false"
response["Cache-Control"] = "no-cache"
response["Server-Timing"] = timer.to_header_string()
return response
Expand All @@ -1045,6 +1050,7 @@ def session_logs(self, request, pk=None, **kwargs):
event_types_str = params.get("event_types")
exclude_types_str = params.get("exclude_types")
limit = params.get("limit", 1000)
offset = params.get("offset", 0)

event_types = {t.strip() for t in event_types_str.split(",") if t.strip()} if event_types_str else None
exclude_types = {t.strip() for t in exclude_types_str.split(",") if t.strip()} if exclude_types_str else None
Expand Down Expand Up @@ -1074,37 +1080,58 @@ def session_logs(self, request, pk=None, **kwargs):

filtered.append(entry)

if len(filtered) >= limit:
break
matching_count = len(filtered)
page = filtered[offset : offset + limit]
has_more = offset + len(page) < matching_count

response = JsonResponse(filtered, safe=False)
response = JsonResponse(page, safe=False)
response["X-Total-Count"] = str(total_count)
response["X-Filtered-Count"] = str(len(filtered))
response["X-Filtered-Count"] = str(matching_count)
response["X-Matching-Count"] = str(matching_count)
response["X-Has-More"] = "true" if has_more else "false"
response["Cache-Control"] = "no-cache"
response["Server-Timing"] = timer.to_header_string()
return response

@action(detail=True, methods=["get"], url_path="stream", required_scopes=["task:read"])
@staticmethod
def _format_sse_event(data: dict, *, event_id: str | None = None, event_name: str | None = None) -> bytes:
parts: list[str] = []
if event_name:
parts.append(f"event: {event_name}")
if event_id:
parts.append(f"id: {event_id}")
parts.append(f"data: {json.dumps(data)}")
return ("\n".join(parts) + "\n\n").encode()

@action(
detail=True,
methods=["get"],
url_path="stream",
required_scopes=["task:read"],
renderer_classes=[ServerSentEventRenderer],
)
def stream(self, request, pk=None, **kwargs):
task_run = cast(TaskRun, self.get_object())
stream_key = get_task_run_stream_key(str(task_run.id))
last_event_id = request.headers.get("Last-Event-ID")
start_latest = request.GET.get("start") == "latest"
format_sse_event = self._format_sse_event

async def async_stream() -> AsyncGenerator[bytes, None]:
redis_stream = TaskRunRedisStream(stream_key)
if not await redis_stream.wait_for_stream():
yield b'event: error\ndata: {"error":"Stream not available"}\n\n'
yield format_sse_event({"error": "Stream not available"}, event_name="error")
return

start_id = last_event_id or "0"
if not last_event_id and start_latest:
start_id = await redis_stream.get_latest_stream_id() or "0"
try:
async for event in redis_stream.read_stream():
yield f"data: {json.dumps(event)}\n\n".encode()
async for event_id, event in redis_stream.read_stream_entries(start_id=start_id):
yield format_sse_event(event, event_id=event_id)
except TaskRunStreamError as e:
logger.error(
"TaskRunRedisStream error for stream %s: %s",
stream_key,
e,
exc_info=True,
)
yield b'event: error\ndata: {"error": "Stream error"}\n\n'
logger.error("TaskRunRedisStream error for stream %s: %s", stream_key, e, exc_info=True)
yield format_sse_event({"error": str(e)}, event_name="error")

return StreamingHttpResponse(
async_stream() if settings.SERVER_GATEWAY_INTERFACE == "ASGI" else async_to_sync(lambda: async_stream()),
Expand Down
28 changes: 27 additions & 1 deletion products/tasks/backend/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import uuid
import string
import secrets
from typing import TYPE_CHECKING, Literal, Optional
from typing import TYPE_CHECKING, Any, Literal, Optional

if TYPE_CHECKING:
from products.slack_app.backend.slack_thread import SlackThreadContext
Expand All @@ -28,6 +28,7 @@
from posthog.temporal.oauth import PosthogMcpScopes

from products.tasks.backend.constants import DEFAULT_TRUSTED_DOMAINS
from products.tasks.backend.stream.redis_stream import publish_task_run_stream_event

logger = structlog.get_logger(__name__)

Expand Down Expand Up @@ -199,6 +200,7 @@ def create_run(
state=state,
branch=branch,
)
task_run.publish_stream_state_event()
self.capture_event(
"task_run_created",
{
Expand Down Expand Up @@ -513,6 +515,7 @@ def mark_completed(self):
self.status = self.Status.COMPLETED
self.completed_at = timezone.now()
self.save(update_fields=["status", "completed_at"])
self.publish_stream_state_event()
self.capture_event(
"task_run_completed",
{"duration_seconds": self._duration_seconds()},
Expand All @@ -524,6 +527,7 @@ def mark_failed(self, error: str):
self.error_message = error
self.completed_at = timezone.now()
self.save(update_fields=["status", "error_message", "completed_at"])
self.publish_stream_state_event()
self.capture_event(
"task_run_failed",
{
Expand All @@ -532,6 +536,26 @@ def mark_failed(self, error: str):
},
)

def build_stream_state_event(self) -> dict[str, Any]:
    """Snapshot this run's user-visible state as a stream event payload."""
    updated_at = self.updated_at.isoformat() if self.updated_at else None
    completed_at = self.completed_at.isoformat() if self.completed_at else None
    return {
        "type": "task_run_state",
        "run_id": str(self.id),
        "task_id": str(self.task_id),
        "status": self.status,
        "stage": self.stage,
        "output": self.output,
        "branch": self.branch,
        "error_message": self.error_message,
        "updated_at": updated_at,
        "completed_at": completed_at,
    }

def publish_stream_event(self, event: dict[str, Any]) -> None:
    """Mirror *event* onto this run's live Redis stream for SSE consumers."""
    run_id = str(self.id)
    publish_task_run_stream_event(run_id, event)

def publish_stream_state_event(self) -> None:
    """Publish a fresh snapshot of this run's state to its live stream."""
    state_event = self.build_stream_state_event()
    self.publish_stream_event(state_event)

def emit_console_event(self, level: LogLevel, message: str) -> None:
"""Emit a console-style log event in ACP notification format."""
event = {
Expand All @@ -548,6 +572,7 @@ def emit_console_event(self, level: LogLevel, message: str) -> None:
},
}
self.append_log([event])
self.publish_stream_event(event)

def emit_sandbox_output(self, stdout: str, stderr: str, exit_code: int) -> None:
"""Emit sandbox execution output as ACP notification."""
Expand All @@ -566,6 +591,7 @@ def emit_sandbox_output(self, stdout: str, stderr: str, exit_code: int) -> None:
},
}
self.append_log([event])
self.publish_stream_event(event)

@property
def is_terminal(self) -> bool:
Expand Down
6 changes: 6 additions & 0 deletions products/tasks/backend/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,12 @@ class TaskRunSessionLogsQuerySerializer(serializers.Serializer):
max_value=5000,
help_text="Maximum number of entries to return (default 1000, max 5000)",
)
offset = serializers.IntegerField(
required=False,
default=0,
min_value=0,
help_text="Zero-based offset into the filtered log entries",
)


class SandboxEnvironmentSerializer(serializers.ModelSerializer):
Expand Down
67 changes: 60 additions & 7 deletions products/tasks/backend/stream/redis_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,28 @@

import structlog
import redis.exceptions as redis_exceptions
from asgiref.sync import async_to_sync

from posthog.redis import get_async_client

logger = structlog.get_logger(__name__)

TASK_RUN_STREAM_MAX_LENGTH = 2000
# Keep enough live history for users who open an in-progress run late while
# still bounding Redis growth for streams with a one-hour TTL.
TASK_RUN_STREAM_MAX_LENGTH = 20_000
TASK_RUN_STREAM_TIMEOUT = 60 * 60 # 60 minutes
TASK_RUN_STREAM_PREFIX = "task-run-stream:"
TASK_RUN_STREAM_READ_COUNT = 16

DATA_KEY = b"data"


def _normalize_stream_id(stream_id: str | bytes) -> str:
if isinstance(stream_id, bytes):
return stream_id.decode("utf-8")
return stream_id


class TaskRunStreamError(Exception):
    """Raised when reading from or waiting on a task-run Redis stream fails."""

Expand Down Expand Up @@ -76,15 +85,33 @@ async def wait_for_stream(self) -> bool:
await asyncio.sleep(delay)
delay = min(delay + delay_increment, max_delay)

async def get_latest_stream_id(self) -> str | None:
    """Return the newest entry ID in the stream, or None when it is empty."""
    entries = await self._redis_client.xrevrange(self._stream_key, count=1)
    if not entries:
        return None
    latest_id, _entry = entries[0]
    return _normalize_stream_id(latest_id)

async def read_stream(
    self,
    start_id: str = "0",
    block_ms: int = 100,
    count: Optional[int] = TASK_RUN_STREAM_READ_COUNT,
) -> AsyncGenerator[dict, None]:
    """Yield only the event payloads, dropping the Redis stream entry IDs."""
    entries = self.read_stream_entries(start_id=start_id, block_ms=block_ms, count=count)
    async for _entry_id, payload in entries:
        yield payload

async def read_stream_entries(
self,
start_id: str = "0",
block_ms: int = 100,
count: Optional[int] = TASK_RUN_STREAM_READ_COUNT,
) -> AsyncGenerator[tuple[str, dict], None]:
"""Read events from the Redis stream.

Yields parsed JSON dicts. Stops when a complete sentinel is received.
Yields Redis stream IDs and parsed JSON dicts.
Stops when a complete sentinel is received.
Raises TaskRunStreamError on error sentinel or timeout.
"""
current_id = start_id
Expand All @@ -106,7 +133,8 @@ async def read_stream(

for _, stream_messages in messages:
for stream_id, message in stream_messages:
current_id = stream_id
normalized_stream_id = _normalize_stream_id(stream_id)
current_id = normalized_stream_id
raw = message.get(DATA_KEY, b"")
data = json.loads(raw)

Expand All @@ -117,7 +145,7 @@ async def read_stream(
elif status == "error":
raise TaskRunStreamError(data.get("error", "Unknown stream error"))
else:
yield data
yield normalized_stream_id, data

except (TaskRunStreamError, GeneratorExit):
raise
Expand All @@ -128,15 +156,22 @@ async def read_stream(
except redis_exceptions.RedisError:
raise TaskRunStreamError("Stream read error")

async def write_event(self, event: dict) -> str:
    """Append *event* to the stream and return its assigned entry ID.

    Every write also refreshes the stream TTL (sliding window) so
    long-running tasks don't expire mid-stream. This matters most for the
    sync publish path (publish_task_run_stream_event), which never goes
    through initialize().
    """
    payload = json.dumps(event)
    entry_id = await self._redis_client.xadd(
        self._stream_key,
        {DATA_KEY: payload},
        maxlen=self._max_length,
        approximate=True,
    )
    await self._redis_client.expire(self._stream_key, self._timeout)
    return _normalize_stream_id(entry_id)

async def mark_complete(self) -> None:
"""Write a completion sentinel to signal end of stream."""
Expand All @@ -153,3 +188,21 @@ async def delete_stream(self) -> bool:
except Exception:
logger.exception("task_run_stream_delete_failed", stream_key=self._stream_key)
return False


def publish_task_run_stream_event(run_id: str, event: dict) -> str | None:
    """Synchronously mirror a task-run event into the live SSE Redis stream.

    Intended for sync Django model/view code. Returns the Redis entry ID on
    success, or None when publishing fails — failures are logged, never
    raised, so stream mirroring stays best-effort.
    """

    async def _write() -> str:
        stream = TaskRunRedisStream(get_task_run_stream_key(run_id))
        return await stream.write_event(event)

    try:
        return async_to_sync(_write)()
    except Exception:
        logger.exception("task_run_stream_publish_failed", run_id=run_id)
        return None
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from unittest.mock import patch

from asgiref.sync import async_to_sync

Expand Down Expand Up @@ -45,6 +46,15 @@ def test_updates_error_message(self, activity_environment, test_task_run):
test_task_run.refresh_from_db()
assert test_task_run.error_message == error_msg

@pytest.mark.django_db
@patch("products.tasks.backend.models.TaskRun.publish_stream_state_event")
def test_publishes_stream_state_event(self, mock_publish_stream_state_event, activity_environment, test_task_run):
    # Updating the run's status must mirror the new state onto the live stream.
    status_input = UpdateTaskRunStatusInput(run_id=str(test_task_run.id), status=TaskRun.Status.IN_PROGRESS)

    async_to_sync(activity_environment.run)(update_task_run_status, status_input)

    mock_publish_stream_state_event.assert_called_once()

@pytest.mark.django_db
def test_handles_non_existent_task_run(self, activity_environment):
non_existent_run_id = "550e8400-e29b-41d4-a716-446655440000"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def update_task_run_status(input: UpdateTaskRunStatusInput) -> None:
task_run.completed_at = timezone.now()

task_run.save(update_fields=["status", "error_message", "completed_at", "updated_at"])
task_run.publish_stream_state_event()

log_with_activity_context(
"Task run status updated",
Expand Down
Loading
Loading