Skip to content

Commit 8fad762

Browse files
committed
fix(cloud-agent): use StrEnum for authorship/source constants and typed RunState model
Address PR review feedback: replace string constants with PrAuthorshipMode and RunSource StrEnum classes, and introduce a RunState Pydantic model for typed access to TaskRun.state fields.
1 parent 94f6c3b commit 8fad762

4 files changed

Lines changed: 70 additions & 51 deletions

File tree

products/tasks/backend/api.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
from .services.connection_token import create_sandbox_connection_token
6464
from .stream.redis_stream import TaskRunRedisStream, TaskRunStreamError, get_task_run_stream_key
6565
from .temporal.client import execute_posthog_code_agent_relay_workflow, execute_task_processing_workflow
66-
from .temporal.process_task.utils import PR_AUTHORSHIP_MODE_USER, cache_github_user_token
66+
from .temporal.process_task.utils import PrAuthorshipMode, cache_github_user_token, parse_run_state
6767

6868
logger = logging.getLogger(__name__)
6969

@@ -261,30 +261,26 @@ def run(self, request, pk=None, **kwargs):
261261
return Response({"detail": "Invalid resume_from_run_id"}, status=400)
262262

263263
# Derive snapshot_external_id from the validated previous run
264-
snapshot_ext_id = (previous_run.state or {}).get("snapshot_external_id")
264+
prev_state = parse_run_state(previous_run.state)
265265
extra_state = {
266266
"resume_from_run_id": str(resume_from_run_id),
267267
}
268268
if pending_user_message is not None:
269269
extra_state["pending_user_message"] = pending_user_message
270-
if snapshot_ext_id:
271-
extra_state["snapshot_external_id"] = snapshot_ext_id
270+
if prev_state.snapshot_external_id:
271+
extra_state["snapshot_external_id"] = prev_state.snapshot_external_id
272272

273-
prev_sandbox_env_id = (previous_run.state or {}).get("sandbox_environment_id")
274-
if prev_sandbox_env_id and sandbox_environment_id is None:
275-
sandbox_environment_id = prev_sandbox_env_id
273+
if prev_state.sandbox_environment_id and sandbox_environment_id is None:
274+
sandbox_environment_id = prev_state.sandbox_environment_id
276275

277-
previous_state = previous_run.state or {}
278276
if pr_authorship_mode is None:
279-
pr_authorship_mode = previous_state.get("pr_authorship_mode")
277+
pr_authorship_mode = prev_state.pr_authorship_mode
280278
if run_source is None:
281-
run_source = previous_state.get("run_source")
279+
run_source = prev_state.run_source
282280
if signal_report_id is None:
283-
signal_report_id = previous_state.get("signal_report_id")
284-
if branch is None:
285-
previous_base_branch = previous_state.get("pr_base_branch")
286-
if isinstance(previous_base_branch, str):
287-
branch = previous_base_branch
281+
signal_report_id = prev_state.signal_report_id
282+
if branch is None and prev_state.pr_base_branch is not None:
283+
branch = prev_state.pr_base_branch
288284

289285
for key, value in {
290286
"pr_base_branch": branch,
@@ -297,7 +293,7 @@ def run(self, request, pk=None, **kwargs):
297293
extra_state[key] = value
298294

299295
# Only require a user token when the task has a repo (no-repo cloud runs skip GitHub operations)
300-
if pr_authorship_mode == PR_AUTHORSHIP_MODE_USER and task.repository and not github_user_token:
296+
if pr_authorship_mode == PrAuthorshipMode.USER and task.repository and not github_user_token:
301297
return Response({"detail": "github_user_token is required for user-authored cloud runs"}, status=400)
302298

303299
if sandbox_environment_id is not None:
@@ -322,7 +318,7 @@ def run(self, request, pk=None, **kwargs):
322318

323319
task_run = task.create_run(mode=mode, branch=branch, extra_state=extra_state)
324320

325-
if github_user_token and pr_authorship_mode == PR_AUTHORSHIP_MODE_USER:
321+
if github_user_token and pr_authorship_mode == PrAuthorshipMode.USER:
326322
cache_github_user_token(str(task_run.id), github_user_token)
327323

328324
logger.info(f"Triggering workflow for task {task.id}, run {task_run.id}")
@@ -862,16 +858,15 @@ def connection_token(self, request, pk=None, **kwargs):
862858
)
863859
def command(self, request, pk=None, **kwargs):
864860
task_run = cast(TaskRun, self.get_object())
865-
state = task_run.state or {}
861+
run_state = parse_run_state(task_run.state)
866862

867-
sandbox_url = state.get("sandbox_url")
868-
if not sandbox_url:
863+
if not run_state.sandbox_url:
869864
return Response(
870865
ErrorResponseSerializer({"error": "No active sandbox for this task run"}).data,
871866
status=status.HTTP_400_BAD_REQUEST,
872867
)
873868

874-
if not self._is_valid_sandbox_url(sandbox_url):
869+
if not self._is_valid_sandbox_url(run_state.sandbox_url):
875870
logger.warning(f"Blocked request to disallowed sandbox URL for task run {task_run.id}")
876871
return Response(
877872
ErrorResponseSerializer({"error": "Invalid sandbox URL"}).data,
@@ -884,8 +879,6 @@ def command(self, request, pk=None, **kwargs):
884879
distinct_id=request.user.distinct_id,
885880
)
886881

887-
sandbox_connect_token = state.get("sandbox_connect_token")
888-
889882
command_payload: dict = {
890883
"jsonrpc": request.validated_data["jsonrpc"],
891884
"method": request.validated_data["method"],
@@ -897,9 +890,9 @@ def command(self, request, pk=None, **kwargs):
897890

898891
try:
899892
agent_response = self._proxy_command_to_agent_server(
900-
sandbox_url=sandbox_url,
893+
sandbox_url=run_state.sandbox_url,
901894
connection_token=connection_token,
902-
sandbox_connect_token=sandbox_connect_token,
895+
sandbox_connect_token=run_state.sandbox_connect_token,
903896
payload=command_payload,
904897
)
905898

products/tasks/backend/serializers.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,7 @@
1010

1111
from .models import SandboxEnvironment, Task, TaskRun
1212
from .services.title_generator import generate_task_title
13-
from .temporal.process_task.utils import (
14-
PR_AUTHORSHIP_MODE_BOT,
15-
PR_AUTHORSHIP_MODE_USER,
16-
RUN_SOURCE_MANUAL,
17-
RUN_SOURCE_SIGNAL_REPORT,
18-
)
13+
from .temporal.process_task.utils import PrAuthorshipMode, RunSource
1914

2015
PRESIGNED_URL_CACHE_TTL = 55 * 60 # 55 minutes (less than 1 hour URL expiry)
2116

@@ -363,8 +358,8 @@ class ConnectionTokenResponseSerializer(serializers.Serializer):
363358
class TaskRunCreateRequestSerializer(serializers.Serializer):
364359
"""Request body for creating a new task run"""
365360

366-
PR_AUTHORSHIP_MODE_CHOICES = [PR_AUTHORSHIP_MODE_USER, PR_AUTHORSHIP_MODE_BOT]
367-
RUN_SOURCE_CHOICES = [RUN_SOURCE_MANUAL, RUN_SOURCE_SIGNAL_REPORT]
361+
PR_AUTHORSHIP_MODE_CHOICES = [mode.value for mode in PrAuthorshipMode]
362+
RUN_SOURCE_CHOICES = [source.value for source in RunSource]
368363

369364
mode = serializers.ChoiceField(
370365
choices=["interactive", "background"],

products/tasks/backend/temporal/process_task/activities/get_sandbox_for_repository.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
get_sandbox_api_url,
2121
get_sandbox_github_token,
2222
get_sandbox_name_for_task,
23+
parse_run_state,
2324
)
2425

2526
from .get_task_processing_context import TaskProcessingContext
@@ -158,14 +159,15 @@ def get_sandbox_for_repository(input: GetSandboxForRepositoryInput) -> GetSandbo
158159

159160
environment_variables.update(get_git_identity_env_vars(task, ctx.state))
160161

162+
run_state = parse_run_state(ctx.state)
163+
161164
# Set resume run ID independently of snapshot so conversation history
162165
# can be rebuilt from logs even when the filesystem snapshot has expired.
163-
resume_from_run_id = (ctx.state or {}).get("resume_from_run_id", "")
164-
if resume_from_run_id:
165-
environment_variables["POSTHOG_RESUME_RUN_ID"] = resume_from_run_id
166+
if run_state.resume_from_run_id:
167+
environment_variables["POSTHOG_RESUME_RUN_ID"] = run_state.resume_from_run_id
166168

167169
# Check for resume snapshot (takes priority over integration-level snapshots)
168-
resume_snapshot_ext_id = (ctx.state or {}).get("snapshot_external_id")
170+
resume_snapshot_ext_id = run_state.snapshot_external_id
169171
if resume_snapshot_ext_id:
170172
used_snapshot = True
171173

products/tasks/backend/temporal/process_task/utils.py

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,54 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass, field
4+
from enum import StrEnum
45
from typing import TYPE_CHECKING, Any, Optional
56
from urllib.parse import urlparse
67

78
from django.conf import settings
89
from django.core.cache import cache
910

11+
from pydantic import BaseModel
12+
1013
from posthog.models.integration import GitHubIntegration, Integration
1114
from posthog.temporal.oauth import PosthogMcpScopes, has_write_scopes
1215

1316
if TYPE_CHECKING:
1417
from products.tasks.backend.models import Task
1518

1619

17-
PR_AUTHORSHIP_MODE_USER = "user"
18-
PR_AUTHORSHIP_MODE_BOT = "bot"
19-
RUN_SOURCE_MANUAL = "manual"
20-
RUN_SOURCE_SIGNAL_REPORT = "signal_report"
20+
class PrAuthorshipMode(StrEnum):
21+
USER = "user"
22+
BOT = "bot"
23+
24+
25+
class RunSource(StrEnum):
26+
MANUAL = "manual"
27+
SIGNAL_REPORT = "signal_report"
28+
29+
30+
class RunState(BaseModel, extra="allow"):
31+
pr_authorship_mode: PrAuthorshipMode | None = None
32+
pr_base_branch: str | None = None
33+
run_source: RunSource | None = None
34+
signal_report_id: str | None = None
35+
resume_from_run_id: str | None = None
36+
snapshot_external_id: str | None = None
37+
sandbox_id: str | None = None
38+
sandbox_url: str | None = None
39+
sandbox_connect_token: str | None = None
40+
sandbox_environment_id: str | None = None
41+
pending_user_message: str | None = None
42+
pending_user_message_ts: str | None = None
43+
slack_thread_url: str | None = None
44+
interaction_origin: str | None = None
45+
slack_sent_relay_ids: list[str] | None = None
46+
47+
48+
def parse_run_state(state: dict[str, Any] | None) -> RunState:
49+
return RunState.model_validate(state or {})
50+
51+
2152
GITHUB_USER_TOKEN_CACHE_TTL_SECONDS = 6 * 60 * 60
2253

2354

@@ -119,8 +150,8 @@ def get_cached_github_user_token(run_id: str) -> str | None:
119150
def get_sandbox_github_token(
120151
github_integration_id: int | None, *, run_id: str, state: dict[str, Any] | None = None
121152
) -> str | None:
122-
authorship_mode = (state or {}).get("pr_authorship_mode")
123-
if authorship_mode == PR_AUTHORSHIP_MODE_USER:
153+
run_state = parse_run_state(state)
154+
if run_state.pr_authorship_mode == PrAuthorshipMode.USER:
124155
github_user_token = get_cached_github_user_token(run_id)
125156
if not github_user_token:
126157
raise ValueError(
@@ -184,22 +215,20 @@ def build_sandbox_environment_variables(
184215
return env_vars
185216

186217

187-
def get_pr_authorship_mode(task: Task, state: dict[str, Any] | None = None) -> str:
218+
def get_pr_authorship_mode(task: Task, state: dict[str, Any] | None = None) -> PrAuthorshipMode:
188219
"""Return the effective PR authorship mode for a run.
189220
190221
Newer cloud runs store the mode in ``TaskRun.state``. Older user-created
191222
runs fall back to user authorship so they still get a human git identity.
192223
"""
193224
from products.tasks.backend.models import Task as TaskModel
194225

195-
mode = (state or {}).get("pr_authorship_mode")
196-
if mode in {PR_AUTHORSHIP_MODE_USER, PR_AUTHORSHIP_MODE_BOT}:
197-
return mode
226+
run_state = parse_run_state(state)
227+
if run_state.pr_authorship_mode is not None:
228+
return run_state.pr_authorship_mode
198229

199230
return (
200-
PR_AUTHORSHIP_MODE_USER
201-
if task.origin_product == TaskModel.OriginProduct.USER_CREATED
202-
else PR_AUTHORSHIP_MODE_BOT
231+
PrAuthorshipMode.USER if task.origin_product == TaskModel.OriginProduct.USER_CREATED else PrAuthorshipMode.BOT
203232
)
204233

205234

@@ -210,7 +239,7 @@ def get_git_identity_env_vars(task: Task, state: dict[str, Any] | None = None) -
210239
Bot-authored runs fall back to the Dockerfile defaults ("PostHog Code" /
211240
code@posthog.com).
212241
"""
213-
if get_pr_authorship_mode(task, state) != PR_AUTHORSHIP_MODE_USER:
242+
if get_pr_authorship_mode(task, state) != PrAuthorshipMode.USER:
214243
return {}
215244

216245
user = task.created_by

0 commit comments

Comments
 (0)