6363from .services .connection_token import create_sandbox_connection_token
6464from .stream .redis_stream import TaskRunRedisStream , TaskRunStreamError , get_task_run_stream_key
6565from .temporal .client import execute_posthog_code_agent_relay_workflow , execute_task_processing_workflow
66- from .temporal .process_task .utils import PR_AUTHORSHIP_MODE_USER , cache_github_user_token
66+ from .temporal .process_task .utils import PrAuthorshipMode , cache_github_user_token , parse_run_state
6767
6868logger = logging .getLogger (__name__ )
6969
@@ -261,30 +261,26 @@ def run(self, request, pk=None, **kwargs):
261261 return Response ({"detail" : "Invalid resume_from_run_id" }, status = 400 )
262262
263263 # Derive snapshot_external_id from the validated previous run
264- snapshot_ext_id = (previous_run .state or {}). get ( "snapshot_external_id" )
264+ prev_state = parse_run_state (previous_run .state )
265265 extra_state = {
266266 "resume_from_run_id" : str (resume_from_run_id ),
267267 }
268268 if pending_user_message is not None :
269269 extra_state ["pending_user_message" ] = pending_user_message
270- if snapshot_ext_id :
271- extra_state ["snapshot_external_id" ] = snapshot_ext_id
270+ if prev_state . snapshot_external_id :
271+ extra_state ["snapshot_external_id" ] = prev_state . snapshot_external_id
272272
273- prev_sandbox_env_id = (previous_run .state or {}).get ("sandbox_environment_id" )
274- if prev_sandbox_env_id and sandbox_environment_id is None :
275- sandbox_environment_id = prev_sandbox_env_id
273+ if prev_state .sandbox_environment_id and sandbox_environment_id is None :
274+ sandbox_environment_id = prev_state .sandbox_environment_id
276275
277- previous_state = previous_run .state or {}
278276 if pr_authorship_mode is None :
279- pr_authorship_mode = previous_state . get ( " pr_authorship_mode" )
277+ pr_authorship_mode = prev_state . pr_authorship_mode
280278 if run_source is None :
281- run_source = previous_state . get ( " run_source" )
279+ run_source = prev_state . run_source
282280 if signal_report_id is None :
283- signal_report_id = previous_state .get ("signal_report_id" )
284- if branch is None :
285- previous_base_branch = previous_state .get ("pr_base_branch" )
286- if isinstance (previous_base_branch , str ):
287- branch = previous_base_branch
281+ signal_report_id = prev_state .signal_report_id
282+ if branch is None and prev_state .pr_base_branch is not None :
283+ branch = prev_state .pr_base_branch
288284
289285 for key , value in {
290286 "pr_base_branch" : branch ,
@@ -297,7 +293,7 @@ def run(self, request, pk=None, **kwargs):
297293 extra_state [key ] = value
298294
299295 # Only require a user token when the task has a repo (no-repo cloud runs skip GitHub operations)
300- if pr_authorship_mode == PR_AUTHORSHIP_MODE_USER and task .repository and not github_user_token :
296+ if pr_authorship_mode == PrAuthorshipMode . USER and task .repository and not github_user_token :
301297 return Response ({"detail" : "github_user_token is required for user-authored cloud runs" }, status = 400 )
302298
303299 if sandbox_environment_id is not None :
@@ -322,7 +318,7 @@ def run(self, request, pk=None, **kwargs):
322318
323319 task_run = task .create_run (mode = mode , branch = branch , extra_state = extra_state )
324320
325- if github_user_token and pr_authorship_mode == PR_AUTHORSHIP_MODE_USER :
321+ if github_user_token and pr_authorship_mode == PrAuthorshipMode . USER :
326322 cache_github_user_token (str (task_run .id ), github_user_token )
327323
328324 logger .info (f"Triggering workflow for task { task .id } , run { task_run .id } " )
@@ -862,16 +858,15 @@ def connection_token(self, request, pk=None, **kwargs):
862858 )
863859 def command (self , request , pk = None , ** kwargs ):
864860 task_run = cast (TaskRun , self .get_object ())
865- state = task_run .state or {}
861+ run_state = parse_run_state ( task_run .state )
866862
867- sandbox_url = state .get ("sandbox_url" )
868- if not sandbox_url :
863+ if not run_state .sandbox_url :
869864 return Response (
870865 ErrorResponseSerializer ({"error" : "No active sandbox for this task run" }).data ,
871866 status = status .HTTP_400_BAD_REQUEST ,
872867 )
873868
874- if not self ._is_valid_sandbox_url (sandbox_url ):
869+ if not self ._is_valid_sandbox_url (run_state . sandbox_url ):
875870 logger .warning (f"Blocked request to disallowed sandbox URL for task run { task_run .id } " )
876871 return Response (
877872 ErrorResponseSerializer ({"error" : "Invalid sandbox URL" }).data ,
@@ -884,8 +879,6 @@ def command(self, request, pk=None, **kwargs):
884879 distinct_id = request .user .distinct_id ,
885880 )
886881
887- sandbox_connect_token = state .get ("sandbox_connect_token" )
888-
889882 command_payload : dict = {
890883 "jsonrpc" : request .validated_data ["jsonrpc" ],
891884 "method" : request .validated_data ["method" ],
@@ -897,9 +890,9 @@ def command(self, request, pk=None, **kwargs):
897890
898891 try :
899892 agent_response = self ._proxy_command_to_agent_server (
900- sandbox_url = sandbox_url ,
893+ sandbox_url = run_state . sandbox_url ,
901894 connection_token = connection_token ,
902- sandbox_connect_token = sandbox_connect_token ,
895+ sandbox_connect_token = run_state . sandbox_connect_token ,
903896 payload = command_payload ,
904897 )
905898
0 commit comments