From f119d26331c8789ee7bae560e0c7955b5c31ced2 Mon Sep 17 00:00:00 2001 From: Dong Wang Date: Fri, 13 Mar 2026 00:04:44 +0000 Subject: [PATCH 1/7] Add orchestrator dispatcher to simplify fault tolerance --- matrix/agents/agent_actor.py | 123 ++++++---- matrix/agents/config/agents/sink.yaml | 4 +- matrix/agents/config/simulation.yaml | 8 +- matrix/agents/dispatcher.py | 131 ++++++++++ matrix/agents/p2p_agents.py | 101 ++++---- matrix/agents/sink.py | 339 +------------------------- 6 files changed, 280 insertions(+), 426 deletions(-) create mode 100644 matrix/agents/dispatcher.py diff --git a/matrix/agents/agent_actor.py b/matrix/agents/agent_actor.py index 5cca1d1..a4ca632 100644 --- a/matrix/agents/agent_actor.py +++ b/matrix/agents/agent_actor.py @@ -16,7 +16,7 @@ from omegaconf import DictConfig from ray.util.metrics import Counter, Gauge -from .agent_utils import send_with_retry, setup_logging +from .agent_utils import setup_logging from .orchestrator import BaseResourceClient, Orchestrator logger = logging.getLogger(__name__) @@ -34,6 +34,9 @@ def __init__( config: DictConfig, resources: dict[str, BaseResourceClient], sink: ray.actor.ActorHandle, + dispatcher_name: str = None, + namespace: str = None, + ray_name: str = None, ): # PATCH FIRST - before any HTTP clients are created self._patch_getproxies() @@ -65,8 +68,11 @@ def __init__( # Store sink reference and start event loop # For Sink actor itself, sink will be None (set later via _set_self_as_sink) self.sink = sink - # Local team cache for fast actor lookup, updated from sink on failure - self._local_team_cache: Dict[str, List[ray.actor.ActorHandle]] = {} + # Dispatcher name + namespace for name-based resolution (None for Sink) + self.dispatcher_name = dispatcher_name + self.namespace = namespace + self.ray_name = ray_name # This agent's Ray actor name (for dispatcher identification) + self.dispatcher = None # resolved lazily in _event_loop self.event_loop_task: Optional[asyncio.Task] = ( asyncio.get_event_loop().create_task(self._event_loop()) @@ -81,13 +87,6 @@ def __init__( "Total number of messages processed by this agent", {}, ), - ( - "queue_size", - Gauge, - "agent_queue_size", - "Current queue size for this agent", - {}, - ), ( "messages_received", Counter, @@ -146,6 +145,17 @@ def __init__( {}, ), ] + # queue_size only makes sense for Sink (local queue); dispatched agents use Dispatcher's metric + if self.dispatcher_name is None: + metrics_config.append( + ( + "queue_size", + Gauge, + "agent_queue_size", + "Current queue size for this agent", + {}, + ), + ) self._init_metrics(metrics_config) @staticmethod @@ -219,7 +229,10 @@ async def _handle_task_exception(orchestrator, msg): orchestrator._append( self.agent_id, {"status_ok": False, "error": msg}, self.sink ) - await self.sink.receive_message.remote(orchestrator) + if self.dispatcher is not None: + await self.dispatcher.submit_error.remote(orchestrator, orchestrator.id) + else: + await self.sink.receive_message.remote(orchestrator) def _log_exceptions(task): try: @@ -237,31 +250,58 @@ def _log_exceptions(task): self.tasks_completed.inc() # type: ignore[attr-defined] self.pending_tasks_count.set(len(self.pending_tasks)) # type: ignore[attr-defined] - while self.running: - orchestrator = await self.queue.get() - if orchestrator is None: # Shutdown sentinel - break - latency = time.time() - orchestrator.enqueue_timestamp - self.dequeue_latency.set(latency) # type: ignore[attr-defined] - if self.ENABLE_INSTRUMENTATION: - orchestrator.append_instrumentation( - self.dequeue_latency, self.agent_id, latency # type: ignore[attr-defined] - ) - - # Update queue size after getting message - self.queue_size.set(self.queue.qsize()) # type: ignore[attr-defined] - - task = asyncio.create_task(self._handle(orchestrator)) - # Attach orchestrator to task for error logging - task._orchestrator = orchestrator # type: ignore[attr-defined] + if self.dispatcher_name is not None: + # Pull-based mode: resolve dispatcher by name, notify of (re)start + self.dispatcher = ray.get_actor(self.dispatcher_name, namespace=self.namespace) + await self.dispatcher.agent_started.remote(self.ray_name) + + while self.running: + orchestrator = await self.dispatcher.checkout.remote(self.ray_name) + if orchestrator is None: # Shutdown sentinel + break + latency = time.time() - orchestrator.enqueue_timestamp + self.dequeue_latency.set(latency) # type: ignore[attr-defined] + if self.ENABLE_INSTRUMENTATION: + orchestrator.append_instrumentation( + self.dequeue_latency, self.agent_id, latency # type: ignore[attr-defined] + ) + + task = asyncio.create_task(self._handle(orchestrator)) + task._orchestrator = orchestrator # type: ignore[attr-defined] + + self.pending_tasks.add(task) + self.tasks_started.inc() # type: ignore[attr-defined] + self.pending_tasks_count.set(len(self.pending_tasks)) # type: ignore[attr-defined] - self.pending_tasks.add(task) - self.tasks_started.inc() # type: ignore[attr-defined] - self.pending_tasks_count.set(len(self.pending_tasks)) # type: ignore[attr-defined] + task.add_done_callback(self.pending_tasks.discard) + task.add_done_callback(_log_exceptions) + else: + # Local-queue mode (Sink) + while self.running: + orchestrator = await self.queue.get() + if orchestrator is None: # Shutdown sentinel + break + latency = time.time() - orchestrator.enqueue_timestamp + self.dequeue_latency.set(latency) # type: ignore[attr-defined] + if self.ENABLE_INSTRUMENTATION: + orchestrator.append_instrumentation( + self.dequeue_latency, self.agent_id, latency # type: ignore[attr-defined] + ) + + # Update queue size after getting message + self.queue_size.set(self.queue.qsize()) # type: ignore[attr-defined] + + task = asyncio.create_task(self._handle(orchestrator)) + # Attach orchestrator to task for error logging + task._orchestrator = orchestrator # type: ignore[attr-defined] + + self.pending_tasks.add(task) + self.tasks_started.inc() # type: ignore[attr-defined] + self.pending_tasks_count.set(len(self.pending_tasks)) # type: ignore[attr-defined] - # Clean up completed tasks - task.add_done_callback(self.pending_tasks.discard) - task.add_done_callback(_log_exceptions) + # Clean up completed tasks + task.add_done_callback(self.pending_tasks.discard) + task.add_done_callback(_log_exceptions) if self.pending_tasks: await asyncio.gather(*self.pending_tasks, return_exceptions=True) @@ -275,10 +315,6 @@ async def _handle(self, orchestrator: Orchestrator): result = await self.postprocess(orchestrator, result) if self.agent_id != "_sink": next_state = await orchestrator.update(result, self, self.logger) - if await next_state.is_done(): - next_agent_name = "_sink" - else: - next_agent_name = next_state.current_agent() # temporary if self.ENABLE_INSTRUMENTATION: @@ -293,14 +329,9 @@ async def _handle(self, orchestrator: Orchestrator): self.handle_latency, self.agent_id, latency # type: ignore[attr-defined] ) - # Send to next agent with fault-tolerant retry - self._local_team_cache = await send_with_retry( - next_state, - next_agent_name, - self.sink, - self._local_team_cache, - self.logger, - ) + if self.dispatcher is not None: + # Submit to dispatcher for deterministic routing + await self.dispatcher.submit.remote(next_state, orchestrator.id) # Update last message time after successful send self.last_message_time = time.time() else: diff --git a/matrix/agents/config/agents/sink.yaml b/matrix/agents/config/agents/sink.yaml index fb55dab..f3f435e 100644 --- a/matrix/agents/config/agents/sink.yaml +++ b/matrix/agents/config/agents/sink.yaml @@ -1,9 +1,7 @@ _sink: _target_: matrix.agents.p2p_agents.Sink debug: ${debug} - dead_orchestrator_tracking: ${dead_orchestrator_tracking} - max_concurrent_tasks: ${max_concurrent_tasks} - + ray_resources: num_cpus: 2 # resources: diff --git a/matrix/agents/config/simulation.yaml b/matrix/agents/config/simulation.yaml index c4f6ab7..a8dbc28 100644 --- a/matrix/agents/config/simulation.yaml +++ b/matrix/agents/config/simulation.yaml @@ -2,15 +2,19 @@ matrix: cluster_id: null matrix_dir: null - + # Simulation framework settings max_concurrent_tasks: 100 rate_limit_enqueue: false +# Dispatcher settings +dispatcher: + ray_resources: + num_cpus: 1 + debug: false seed: 42 num_trial: 1 -dead_orchestrator_tracking: false # Default output settings output: diff --git a/matrix/agents/dispatcher.py b/matrix/agents/dispatcher.py new file mode 100644 index 0000000..aceee2f --- /dev/null +++ b/matrix/agents/dispatcher.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import asyncio +import logging +import time +from typing import Dict, Optional + +import ray +from ray.util.metrics import Gauge + +from .orchestrator import DeadOrchestrator, Orchestrator + +logger = logging.getLogger(__name__) + + +class Dispatcher: + """ + Dispatcher actor (one per role) that acts as a message broker. + + Agents pull work from their Dispatcher via checkout(), process it, + and submit results back via submit(). The Dispatcher tracks exactly + which agent has which orchestrator. + + Dead-agent detection is handled by agent restarts: when an agent + restarts (max_restarts=-1), it calls agent_started() which tombstones + any previously checked-out orchestrators for that agent. + """ + + def __init__(self, role: str, sink: ray.actor.ActorHandle, namespace: str): + self.role = role + self.sink = sink + self.namespace = namespace + + self.incoming_queue: asyncio.Queue = asyncio.Queue() + # orch_id -> (agent_ray_name, orchestrator) + self.checked_out: Dict[str, tuple[str, Orchestrator]] = {} + + # Other role Dispatchers for forwarding (handles are stable — dispatchers don't restart) + self.dispatchers: Dict[str, ray.actor.ActorHandle] = {} + # Known agents (ray actor names), populated by agent_started() + self.known_agents: set[str] = set() + + self.logger = logging.getLogger(f"Dispatcher[{role}]") + + self.queue_size = Gauge( + "dispatcher_queue_size", + description="Current queue size for this dispatcher", + tag_keys=("role",), + ) + self.queue_size.set_default_tags({"role": role}) + + async def set_dispatchers(self, dispatchers: Dict[str, ray.actor.ActorHandle]): + """Wire this Dispatcher to other role Dispatchers for forwarding.""" + self.dispatchers = dispatchers + + async def enqueue(self, orchestrator: Orchestrator): + """Called by other Dispatchers or the framework to queue work.""" + orchestrator.enqueue_timestamp = time.time() + await self.incoming_queue.put(orchestrator) + self.queue_size.set(self.incoming_queue.qsize()) + + async def checkout(self, agent_ray_name: str) -> Optional[Orchestrator]: + """ + Agent pulls work, blocks until available. + Returns None as shutdown sentinel. + """ + orchestrator = await self.incoming_queue.get() + self.queue_size.set(self.incoming_queue.qsize()) + if orchestrator is None: + return None + self.checked_out[orchestrator.id] = (agent_ray_name, orchestrator) + return orchestrator + + async def submit(self, processed_orch: Orchestrator, orch_id: str): + """ + Agent acks completion. Dispatcher removes from checked_out, + checks is_done()/current_agent(), forwards to target Dispatcher or Sink. + """ + self.checked_out.pop(orch_id, None) + + if await processed_orch.is_done(): + await self.sink.receive_message.remote(processed_orch) + else: + next_role = processed_orch.current_agent() + if next_role == "_sink": + await self.sink.receive_message.remote(processed_orch) + elif next_role in self.dispatchers: + await self.dispatchers[next_role].enqueue.remote(processed_orch) + else: + self.logger.error( + f"Unknown target role '{next_role}' for orch {orch_id}, sending to sink" + ) + await self.sink.receive_message.remote(processed_orch) + + async def submit_error(self, orchestrator: Orchestrator, orch_id: str): + """Agent error ack. Forward directly to Sink.""" + self.checked_out.pop(orch_id, None) + await self.sink.receive_message.remote(orchestrator) + + async def agent_started(self, agent_ray_name: str): + """ + Called on agent (re)start. Registers the agent and tombstones any + previously checked-out orchestrators for it (handles restart race). + """ + self.known_agents.add(agent_ray_name) + to_tombstone = [ + (oid, orch) + for oid, (aname, orch) in self.checked_out.items() + if aname == agent_ray_name + ] + for oid, orch in to_tombstone: + del self.checked_out[oid] + self.logger.warning( + f"Agent {agent_ray_name} restarted, tombstoning orch {oid}" + ) + dead = DeadOrchestrator( + oid, error=f"Agent {agent_ray_name} restarted while processing" + ) + await self.sink.receive_message.remote(dead) + + async def shutdown(self): + """Put None sentinels in queue (one per known agent) to stop their loops.""" + for _ in self.known_agents: + await self.incoming_queue.put(None) + + async def check_health(self): + return True diff --git a/matrix/agents/p2p_agents.py b/matrix/agents/p2p_agents.py index 827b04d..c7c259b 100644 --- a/matrix/agents/p2p_agents.py +++ b/matrix/agents/p2p_agents.py @@ -33,7 +33,6 @@ HistPair, RayDict, get_ray_actor_class, - send_with_retry, setup_logging, ) from .dataset_loader import BaseDatasetLoader @@ -43,13 +42,13 @@ Orchestrator, SequentialOrchestrator, ) +from .dispatcher import Dispatcher from .sink import Sink # Re-export all public names for backward compatibility __all__ = [ "RayDict", "HistPair", - "send_with_retry", "Orchestrator", "SequentialOrchestrator", "DeadOrchestrator", @@ -60,6 +59,7 @@ "ContainerExecutionAgent", "LLMAgentActor", "Sink", + "Dispatcher", "ScalableTeamManager", "P2PAgentFramework", "main", @@ -88,15 +88,21 @@ def done(self): class ScalableTeamManager: - """Manages teams with multiple actors per role using load balancers when needed""" + """Manages teams with multiple actors per role using Dispatchers for routing and load balancing""" - def __init__(self, simulation_id: str): + def __init__(self, simulation_id: str, dispatcher_config: Optional[DictConfig] = None): self.simulation_id = simulation_id self.teamConfig: Dict[str, Tuple[Type, DictConfig]] = {} - # Team registry config: role -> (count, namespace) for actor lookup - self.team_registry_config: Dict[str, Tuple[int, str]] = {} # Sink actor handle - must be created first self.sink: Optional[ray.actor.ActorHandle] = None + # Dispatcher per role (excluding sink) + self.dispatchers: Dict[str, ray.actor.ActorHandle] = {} + # Dispatcher ray resource config + self.dispatcher_ray_resources: dict[str, Any] = {} + if dispatcher_config and "ray_resources" in dispatcher_config: + self.dispatcher_ray_resources = OmegaConf.to_container( + dispatcher_config["ray_resources"], resolve=True + ) def create_role(self, role_name: str, agent_config: DictConfig, resources): """Create agents for a role. _sink must be created first.""" @@ -117,8 +123,22 @@ def create_role(self, role_name: str, agent_config: DictConfig, resources): # Sink should not restart; other actors restart infinitely max_restarts = 0 if is_sink else -1 + # Create Dispatcher for non-sink roles + dispatcher = None + if not is_sink: + DispatcherActor = ray.remote(Dispatcher) + dispatcher = DispatcherActor.options( + name=f"dispatcher_{role_name}", + namespace=self.simulation_id, + max_restarts=0, + **self.dispatcher_ray_resources, + ).remote(role=role_name, sink=self.sink, namespace=self.simulation_id) + self.dispatchers[role_name] = dispatcher + logger.info(f"Created dispatcher for role: {role_name}") + agents = [] for i in range(count): + ray_name = f"{role_name}_{i}" kwargs = { "id": f"{self.simulation_id}_{role_name}_{i}", "agent_id": role_name, @@ -127,9 +147,12 @@ def create_role(self, role_name: str, agent_config: DictConfig, resources): } if not is_sink: kwargs["sink"] = self.sink + kwargs["dispatcher_name"] = f"dispatcher_{role_name}" + kwargs["namespace"] = self.simulation_id + kwargs["ray_name"] = ray_name agent = agent_class.options( - name=f"{role_name}_{i}", + name=ray_name, namespace=self.simulation_id, max_restarts=max_restarts, **ray_resources, @@ -147,24 +170,22 @@ def create_role(self, role_name: str, agent_config: DictConfig, resources): agent_class.__ray_metadata__.modified_class, agent_config, ) - # Only add non-sink roles to registry (sink won't restart) - if not is_sink: - self.team_registry_config[role_name] = (count, self.simulation_id) return agents async def initialize_team(self, team: Dict[str, List[ray.actor.ActorHandle]]): - """Initialize all agents with team references. + """Initialize all agents and wire Dispatchers. Args: team: Dict mapping role -> list of actor handles (from create_role return values) """ - # Use handles from create_role directly (avoid race with ray.get_actor) + # Health-check all actors and dispatchers all_actors = [self.sink] for role_handles in team.values(): all_actors.extend(role_handles) + all_actors.extend(self.dispatchers.values()) - logger.info(f"Checking Ray actor health for {len(all_actors)} actors") + logger.info(f"Checking Ray actor health for {len(all_actors)} actors (including dispatchers)") try: await asyncio.wait_for( asyncio.gather( @@ -179,18 +200,31 @@ async def initialize_team(self, team: Dict[str, List[ray.actor.ActorHandle]]): raise e logger.info("Checking Ray actor health done...") - # Initialize Sink's team registry with verified handles (avoid re-lookup race) - if self.sink is not None: - await self.sink.set_team_registry.remote(self.team_registry_config, team) + # Wire Dispatchers to each other + for role_name, dispatcher in self.dispatchers.items(): + await dispatcher.set_dispatchers.remote(self.dispatchers) + + logger.info(f"Dispatchers wired: {list(self.dispatchers.keys())}") def get_team_config(self): """Get team config dictionary for orchestrator routing""" return self.teamConfig async def shutdown(self): - """Shutdown all actors via sink's registry""" + """Shutdown Dispatchers (which send sentinels to agents), then Sink.""" + # Shutdown dispatchers first (sends sentinels to agents) + for role_name, dispatcher in self.dispatchers.items(): + try: + await dispatcher.shutdown.remote() + except Exception as e: + logger.warning(f"Error shutting down dispatcher {role_name}: {repr(e)}") + + # Then shutdown sink if self.sink: - await self.sink.shutdown_all.remote() + try: + await self.sink.shutdown.remote() + except Exception as e: + logger.warning(f"Error shutting down sink: {repr(e)}") class P2PAgentFramework: @@ -209,13 +243,12 @@ def __init__(self, sim_index: int, cfg: DictConfig): self.semaphore = asyncio.Semaphore(self.max_concurrent_tasks) self.sink: Sink = None # type: ignore[assignment] - self.team_manager = ScalableTeamManager(self.simulation_id) + self.team_manager = ScalableTeamManager( + self.simulation_id, + dispatcher_config=self.cfg.get("dispatcher"), + ) self.resources: Dict[str, BaseResourceClient] = {} - # Local team cache for latency-sensitive _process_item - # Updated from sink on actor failure - self._local_team_cache: Dict[str, List[ray.actor.ActorHandle]] = {} - random.seed(self.cfg["seed"]) self.num_trial = self.cfg["num_trial"] if self.num_trial > 1: @@ -256,9 +289,6 @@ async def create_team( await self.team_manager.initialize_team(team) self.sink = self.team_manager.sink # type: ignore[assignment] - # Initialize local team cache from sink - self._local_team_cache = await self.sink.get_team_snapshot.remote() # type: ignore[attr-defined] - async def _progress_task(self): async def _update_progress(): done = await self.sink.get_progress.remote() # type: ignore[attr-defined] @@ -342,22 +372,11 @@ async def _process_item(self, trial_item: Tuple[int, Dict[str, Any]]): logger.debug(f"done Init {orchestrator.id}") logger.debug(f"Enqueue: {orchestrator.id}") - # Register as in-flight before sending to first agent - if self.cfg.dead_orchestrator_tracking: - await self.sink.register_inflight.remote(orchestrator.id) # type: ignore[attr-defined] - - # Send to first agent with local cache for latency, fallback to sink on error + # Enqueue to the first agent's Dispatcher try: - self._local_team_cache = await send_with_retry( - orchestrator, - first_agent_role, - self.sink, # type: ignore[arg-type] - self._local_team_cache, - logger, - ) - except RuntimeError as e: - # All retries exhausted - send to sink as error - logger.error(str(e)) + await self.team_manager.dispatchers[first_agent_role].enqueue.remote(orchestrator) + except Exception as e: + logger.error(f"Failed to enqueue to dispatcher for {first_agent_role}: {repr(e)}") orchestrator.status["error"] = f"Failed to reach {first_agent_role}: {e}" await self.sink.receive_message.remote(orchestrator) # type: ignore[attr-defined] return diff --git a/matrix/agents/sink.py b/matrix/agents/sink.py index d474345..bdeaeb2 100644 --- a/matrix/agents/sink.py +++ b/matrix/agents/sink.py @@ -7,10 +7,9 @@ import asyncio import json import os -import random import time from functools import partial -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Optional import ray import zstandard as zstd @@ -33,15 +32,6 @@ # @ray.remote class Sink(AgentActor): - # Constants for team registry - REFRESH_INTERVAL = 30.0 # seconds between periodic refreshes - GET_ACTOR_TIMEOUT = 60.0 # default timeout for get_actor calls - GET_ACTOR_RETRY_INTERVAL = 1.0 # seconds between retries - - # dead detection - IDLE_TIMEOUT = 60.0 # seconds of idle time before checking for dead tasks - MAX_NEW_ZOMBIE = 10 # at most mark this number at once - LATE_ARRIVAL_INCR = 5 # make it harder to mark zombie for each late arrivals def __init__( self, @@ -59,27 +49,6 @@ def __init__( self.pending_writes: int = ( 0 # Track in-progress writes to avoid closing file prematurely ) - - # Team registry: Sink is the source of truth for actor handles - self._team: Dict[str, List[ray.actor.ActorHandle]] = {} - self._team_config: Dict[str, Tuple[int, str]] = {} # role -> (count, namespace) - self._team_lock = asyncio.Lock() - self._refresh_task: Optional[asyncio.Task] = None - - # Optimistic timeout tracking for dead orchestrator detection - self.dead_orchestrator_tracking = config.dead_orchestrator_tracking - self.dead_order_window = config.max_concurrent_tasks - # Registration order tracking: when task N completes, tasks before N-max_concurrent_tasks are zombies - self.registration_counter: int = 0 - self.inflight_orchestrators: dict[str, int] = {} # id -> registration_order - self.inflight_order: list[tuple[int, str]] = ( - [] - ) # sorted (order, id) for efficient oldest iteration, lazy update, have stale items - self.zombie_orchestrators: set[str] = set() # set of zombie orchestrator ids - - # Idle detection for dead task recovery - self.last_message_time: float = time.time() - self._idle_check_task: Optional[asyncio.Task] = None self.num_dead: int = 0 # Counter for dead/lost orchestrators additional_metrics_config: list[tuple[str, type, str, str, dict[str, Any]]] = [ @@ -130,12 +99,8 @@ async def set_metrics_output( async def set_num_inputs(self, num_inputs: int): self.num_inputs = num_inputs - # Start idle check task now that we know the total inputs - if self._idle_check_task is None: - self._idle_check_task = asyncio.create_task(self._idle_check_loop()) async def preprocess(self, orchestrator: "Orchestrator"): - # Update last message time for idle detection self.last_message_time = time.time() def _write_output(output_data, output_path): @@ -169,8 +134,8 @@ def _write_output(output_data, output_path): finally: self.pending_writes -= 1 # Always decrement, even on error - # Get registration order and remove from inflight tracking - completed_order = self.inflight_orchestrators.pop(orchestrator.id, None) + # Increment num_done for ALL arrivals (normal + dead) + self.num_done += 1 if is_tombstone: self.num_dead += 1 @@ -179,58 +144,13 @@ def _write_output(output_data, output_path): latency = orchestrator.init_timestamp - orchestrator.creation_timestamp self.task_init_latency.set(latency) # type: ignore[attr-defined] - # Check if this orchestrator was in zombie set (came back from the dead) - is_zombie_return = orchestrator.id in self.zombie_orchestrators - if is_zombie_return: - # Remove from zombie set - it actually completed - self.zombie_orchestrators.discard(orchestrator.id) - self.logger.info( - f"Zombie orchestrator {orchestrator.id} returned, removing from zombie set" - ) - self.dead_order_window += self.LATE_ARRIVAL_INCR - - if not is_zombie_return: - self.num_done += 1 - - # Position-based zombie detection: when a non-error orchestrator completes, - # all tasks registered more than max_concurrent_tasks before it should be done - if ( - self.dead_orchestrator_tracking - and not orchestrator.is_error() - and completed_order is not None - ): - threshold = completed_order - self.dead_order_window - if threshold > 0: - to_zombify = 0 - # Iterate through oldest entries first, using lazy deletion - # (skip entries already removed from inflight_orchestrators) - while ( - self.inflight_order - and self.inflight_order[0][0] <= threshold - and to_zombify < self.MAX_NEW_ZOMBIE - ): - order, orch_id = self.inflight_order.pop(0) - # Only zombify if still in inflight (not already completed) - if orch_id in self.inflight_orchestrators: - del self.inflight_orchestrators[orch_id] - self.zombie_orchestrators.add(orch_id) - to_zombify += 1 - # Increase num_done to release semaphore - self.num_done += 1 - - if to_zombify: - self.logger.info( - f"Moved {to_zombify} orchestrators to zombie (threshold order: {threshold})" - ) - if self.metrics_accumulator: self.metrics_accumulator.accumulate(orchestrator) - # Don't close output file if there are zombies or pending writes + # Close output file when all tasks are done and no pending writes if ( self.num_inputs is not None and self.num_done >= self.num_inputs - and not self.zombie_orchestrators and self.pending_writes == 0 ): self.output_file.close() @@ -238,18 +158,8 @@ def _write_output(output_data, output_path): return {"orchestrator": orchestrator} async def get_progress(self) -> int: - # Cap num_done when there are zombies to prevent input pipeline from finishing early - if self.zombie_orchestrators and self.num_inputs is not None: - return min(self.num_done, self.num_inputs - 1) return self.num_done - async def register_inflight(self, orchestrator_id: str): - """Register an orchestrator as in-flight (before sending to first agent).""" - self.registration_counter += 1 - order = self.registration_counter - self.inflight_orchestrators[orchestrator_id] = order - self.inflight_order.append((order, orchestrator_id)) - async def get_overall_metrics(self) -> dict[str, Any] | None: return ( self.metrics_accumulator.done() @@ -265,35 +175,9 @@ async def check_health(self): return True async def shutdown(self): - """Gracefully shutdown the Sink agent and cancel background tasks.""" - if self._refresh_task is not None: - self._refresh_task.cancel() - try: - await self._refresh_task - except asyncio.CancelledError: - pass - self._refresh_task = None - if self._idle_check_task is not None: - self._idle_check_task.cancel() - try: - await self._idle_check_task - except asyncio.CancelledError: - pass - self._idle_check_task = None + """Gracefully shutdown the Sink agent.""" await super().shutdown() - async def shutdown_all(self): - """Shutdown all actors in the team registry, then shutdown self.""" - async with self._team_lock: - for role, actors in self._team.items(): - for actor in actors: - try: - await actor.shutdown.remote() - except Exception as e: - self.logger.warning(f"Error shutting down {role}: {repr(e)}") - # Finally shutdown self - await self.shutdown() - async def register_object(self, obj: list[ray.ObjectRef]): o = obj[0] self.ray_objects[o.hex()] = o # type: ignore[attr-defined] @@ -301,216 +185,3 @@ async def register_object(self, obj: list[ray.ObjectRef]): async def unregister_object(self, obj: list[ray.ObjectRef]): for o in obj: self.ray_objects.pop(o.hex(), None) # type: ignore[attr-defined] - - # ==== Team Registry Methods ==== - async def set_team_registry( - self, - team_config: Dict[str, Tuple[int, str]], - initial_team: Optional[Dict[str, List[ray.actor.ActorHandle]]] = None, - ): - """ - Initialize the team registry in Sink. - - Args: - team_config: Dict mapping role -> (count, namespace) for actor lookup - initial_team: Optional pre-verified actor handles to avoid initial lookup race - """ - async with self._team_lock: - self._team_config = team_config - self._team = initial_team if initial_team is not None else {} - - # Only do initial refresh if we didn't get pre-verified handles - if initial_team is None: - await self._refresh_team_internal() - - # Start periodic refresh task - if self._refresh_task is None: - self._refresh_task = asyncio.create_task(self._periodic_refresh()) - self.logger.info( - f"Team registry initialized with roles: {list(team_config.keys())}" - ) - - async def _periodic_refresh(self): - """Background task to periodically refresh actor handles.""" - while self.running: - try: - await asyncio.sleep(self.REFRESH_INTERVAL) - await self._refresh_team_internal() - except asyncio.CancelledError: - break - except Exception as e: - self.logger.warning(f"Error in periodic team refresh: {repr(e)}") - - async def _refresh_team_internal(self): - """Internal method to refresh all actor handles from Ray.""" - async with self._team_lock: - for role, (count, namespace) in self._team_config.items(): - new_handles = [] - for i in range(count): - actor_name = f"{role}_{i}" - try: - handle = ray.get_actor(actor_name, namespace=namespace) - new_handles.append(handle) - except ValueError: - # Actor not found (not yet restarted) - self.logger.warning( - f"Actor {actor_name} not found in namespace {namespace}" - ) - if new_handles: - self._team[role] = new_handles - self.logger.debug(f"Team refreshed: {list(self._team.keys())}") - - async def force_refresh(self): - """Force an immediate refresh of all actor handles.""" - await self._refresh_team_internal() - - async def _write_tombstone(self, orchestrator_id: str): - """ - Write a tombstone record for a lost orchestrator via normal Sink flow. - - Args: - orchestrator_id: The ID of the lost orchestrator - """ - self.logger.debug( - f"Creating tombstone for lost orchestrator: {orchestrator_id}" - ) - - # Create a DeadOrchestrator and process it through normal flow - dead_orch = DeadOrchestrator(orchestrator_id) - await self.receive_message(dead_orch) - - async def _idle_check_loop(self): - """Background task to detect idle state and write tombstones for lost tasks.""" - while self.running: - try: - await asyncio.sleep(10.0) # Check every 10 seconds - - # Skip if we're done and no zombies - has_zombies = bool(self.zombie_orchestrators) - if ( - self.num_inputs is not None - and self.num_done >= self.num_inputs - and not has_zombies - ): - break - - # Check if we've been idle long enough - idle_time = time.time() - self.last_message_time - if idle_time < self.IDLE_TIMEOUT: - continue - - # Check if all actors have empty queues and no pending tasks - all_idle = await self._check_all_actors_idle() - if not all_idle: - continue - - for orch_id in self.inflight_orchestrators: - self.zombie_orchestrators.add(orch_id) - self.num_done += 1 - self.inflight_orchestrators.clear() - - # All actors are idle - confirm zombies are dead and write tombstones - if self.zombie_orchestrators: - self.logger.warning( - f"System idle for {idle_time:.1f}s with {len(self.zombie_orchestrators)} " - f"zombie tasks. Writing tombstones to confirm dead." - ) - tasks = [ - self._write_tombstone(orch_id) - for orch_id in self.zombie_orchestrators - ] - self.zombie_orchestrators.clear() - await asyncio.gather(*tasks) - - # Wait for tombstones to be processed through the queue - while self.queue.qsize() > 0 or len(self.pending_tasks) > 0: - await asyncio.sleep(1) - # File will be closed by preprocess when last tombstone is written - break - - # Check if we're still short of expected tasks - if self.num_inputs is not None and self.num_done < self.num_inputs: - self.logger.error( - f"Completed {self.num_done} tasks but expected {self.num_inputs}. " - f"Some tasks may have been lost before registration." - ) - break - - except asyncio.CancelledError: - break - except Exception as e: - self.logger.warning(f"Error in idle check loop: {repr(e)}") - - async def _check_all_actors_idle(self) -> bool: - """Check if all actors have zero queue + pending tasks and idle timestamps.""" - now = time.time() - async with self._team_lock: - all_actors = [actor for actors in self._team.values() for actor in actors] - - for actor in all_actors: - try: - count, last_msg_time = await asyncio.wait_for( - actor.get_active_count.remote(), - timeout=5.0, - ) - if count > 0: - return False - # Also check if actor's last message time is within IDLE_TIMEOUT - if now - last_msg_time < self.IDLE_TIMEOUT: - return False - except Exception as e: - # If we can't reach an actor, it might be dead - don't consider system idle - self.logger.debug(f"Failed to get active count from actor: {repr(e)}") - return False - - return True - - async def get_actor( - self, role: str, timeout: Optional[float] = None, force_refresh: bool = False - ) -> ray.actor.ActorHandle: - """ - Get a random actor handle for the given role. - Blocks until an actor is available or timeout expires. - - Args: - role: The role name to get an actor for - timeout: Maximum time to wait in seconds (default: GET_ACTOR_TIMEOUT) - force_refresh: If True, refresh handles before returning - - Returns: - A random actor handle for the role - - Raises: - TimeoutError: If no actor is available within timeout - KeyError: If role is not registered - """ - if timeout is None: - timeout = self.GET_ACTOR_TIMEOUT - - if force_refresh: - await self._refresh_team_internal() - - start_time = time.time() - while True: - async with self._team_lock: - if role in self._team and self._team[role]: - return random.choice(self._team[role]) - - # Check timeout - elapsed = time.time() - start_time - if elapsed >= timeout: - raise TimeoutError( - f"Timeout waiting for actor with role '{role}' after {timeout}s" - ) - - # Wait and retry - self.logger.debug( - f"No actor available for role '{role}', retrying in {self.GET_ACTOR_RETRY_INTERVAL}s" - ) - await asyncio.sleep(self.GET_ACTOR_RETRY_INTERVAL) - await self._refresh_team_internal() - - async def get_team_snapshot(self) -> Dict[str, List[ray.actor.ActorHandle]]: - """Get a snapshot of the current team map.""" - async with self._team_lock: - return dict(self._team) From 58c3de401372f5dbfaae41822642fef837f106f3 Mon Sep 17 00:00:00 2001 From: Dong Wang Date: Fri, 13 Mar 2026 00:21:19 +0000 Subject: [PATCH 2/7] retry send to dispatcher and sink in case of network issue --- matrix/agents/agent_actor.py | 49 +++++++++++++++++++++++++++++++++--- matrix/agents/agent_utils.py | 44 ++++++++++++++++++++++++++++++++ matrix/agents/dispatcher.py | 38 +++++++++++++++++++++++----- 3 files changed, 121 insertions(+), 10 deletions(-) diff --git a/matrix/agents/agent_actor.py b/matrix/agents/agent_actor.py index a4ca632..9e39704 100644 --- a/matrix/agents/agent_actor.py +++ b/matrix/agents/agent_actor.py @@ -16,8 +16,8 @@ from omegaconf import DictConfig from ray.util.metrics import Counter, Gauge -from .agent_utils import setup_logging -from .orchestrator import BaseResourceClient, Orchestrator +from .agent_utils import remote_call_with_retry, setup_logging +from .orchestrator import BaseResourceClient, DeadOrchestrator, Orchestrator logger = logging.getLogger(__name__) @@ -230,7 +230,26 @@ async def _handle_task_exception(orchestrator, msg): self.agent_id, {"status_ok": False, "error": msg}, self.sink ) if self.dispatcher is not None: - await self.dispatcher.submit_error.remote(orchestrator, orchestrator.id) + try: + await remote_call_with_retry( + self.dispatcher.submit_error, + orchestrator, + orchestrator.id, + logger=self.logger, + ) + except Exception as e: + self.logger.error( + f"Failed to submit error to dispatcher for {orchestrator.id}: {repr(e)}" + ) + # Fall back to sending directly to sink + try: + await remote_call_with_retry( + self.sink.receive_message, orchestrator, logger=self.logger + ) + except Exception: + self.logger.error( + f"Failed to send error orch {orchestrator.id} to sink, dropping" + ) else: await self.sink.receive_message.remote(orchestrator) @@ -331,7 +350,29 @@ async def _handle(self, orchestrator: Orchestrator): if self.dispatcher is not None: # Submit to dispatcher for deterministic routing - await self.dispatcher.submit.remote(next_state, orchestrator.id) + try: + await remote_call_with_retry( + self.dispatcher.submit, + next_state, + orchestrator.id, + logger=self.logger, + ) + except Exception as e: + self.logger.error( + f"Failed to submit orch {orchestrator.id} to dispatcher: {repr(e)}" + ) + dead = DeadOrchestrator( + orchestrator.id, + error=f"Dispatcher unreachable: {e}", + ) + try: + await remote_call_with_retry( + self.sink.receive_message, dead, logger=self.logger + ) + except Exception: + self.logger.error( + f"Failed to tombstone orch {orchestrator.id} to sink, dropping" + ) # Update last message time after successful send self.last_message_time = time.time() else: diff --git a/matrix/agents/agent_utils.py b/matrix/agents/agent_utils.py index 693d87c..40bb321 100644 --- a/matrix/agents/agent_utils.py +++ b/matrix/agents/agent_utils.py @@ -180,6 +180,50 @@ class HistPair(NamedTuple): response: RayDict +REMOTE_CALL_TIMEOUT = 60.0 # seconds per attempt +REMOTE_CALL_RETRIES = 3 +REMOTE_CALL_BACKOFF = 2.0 # seconds between retries + + +async def remote_call_with_retry( + method, + *args, + retries: int = REMOTE_CALL_RETRIES, + timeout: float = REMOTE_CALL_TIMEOUT, + backoff: float = REMOTE_CALL_BACKOFF, + logger: Optional[logging.Logger] = None, +): + """Call a Ray remote method with retry and timeout. + + Args: + method: Ray actor method (e.g. handle.submit) + *args: Arguments to pass to method.remote(*args) + retries: Number of attempts + timeout: Timeout per attempt in seconds + backoff: Seconds to wait between retries + logger: Logger for warnings + + Returns: + The result of the remote call + + Raises: + The last exception if all retries exhausted + """ + last_error = None + for attempt in range(retries): + try: + return await asyncio.wait_for(method.remote(*args), timeout=timeout) + except Exception as e: + last_error = e + if logger: + logger.warning( + f"Remote call failed (attempt {attempt + 1}/{retries}): {repr(e)}" + ) + if attempt < retries - 1: + await asyncio.sleep(backoff) + raise last_error # type: ignore[misc] + + # ==== Utility Functions ==== async def send_with_retry( orchestrator: "Orchestrator", diff --git a/matrix/agents/dispatcher.py b/matrix/agents/dispatcher.py index aceee2f..da41076 100644 --- a/matrix/agents/dispatcher.py +++ b/matrix/agents/dispatcher.py @@ -12,6 +12,7 @@ import ray from ray.util.metrics import Gauge +from .agent_utils import remote_call_with_retry from .orchestrator import DeadOrchestrator, Orchestrator logger = logging.getLogger(__name__) @@ -83,23 +84,37 @@ async def submit(self, processed_orch: Orchestrator, orch_id: str): self.checked_out.pop(orch_id, None) if await processed_orch.is_done(): - await self.sink.receive_message.remote(processed_orch) + await self._send_to_sink(processed_orch) else: next_role = processed_orch.current_agent() if next_role == "_sink": - await self.sink.receive_message.remote(processed_orch) + await self._send_to_sink(processed_orch) elif next_role in self.dispatchers: - await self.dispatchers[next_role].enqueue.remote(processed_orch) + try: + await remote_call_with_retry( + self.dispatchers[next_role].enqueue, + processed_orch, + logger=self.logger, + ) + except Exception as e: + self.logger.error( + f"Failed to forward orch {orch_id} to dispatcher {next_role}: {repr(e)}" + ) + dead = DeadOrchestrator( + orch_id, + error=f"Failed to forward to {next_role}: {e}", + ) + await self._send_to_sink(dead) else: self.logger.error( f"Unknown target role '{next_role}' for orch {orch_id}, sending to sink" ) - await self.sink.receive_message.remote(processed_orch) + await self._send_to_sink(processed_orch) async def submit_error(self, orchestrator: Orchestrator, orch_id: str): """Agent error ack. Forward directly to Sink.""" self.checked_out.pop(orch_id, None) - await self.sink.receive_message.remote(orchestrator) + await self._send_to_sink(orchestrator) async def agent_started(self, agent_ray_name: str): """ @@ -120,7 +135,18 @@ async def agent_started(self, agent_ray_name: str): dead = DeadOrchestrator( oid, error=f"Agent {agent_ray_name} restarted while processing" ) - await self.sink.receive_message.remote(dead) + await self._send_to_sink(dead) + + async def _send_to_sink(self, orchestrator: Orchestrator): + """Send to sink with retry. If sink is unreachable, log and drop.""" + try: + await remote_call_with_retry( + self.sink.receive_message, orchestrator, logger=self.logger + ) + except Exception as e: + self.logger.error( + f"Failed to send orch {orchestrator.id} to sink after retries: {repr(e)}" + ) async def shutdown(self): """Put None sentinels in queue (one per known agent) to stop their loops.""" From 052d325ba5ff1b149d409328f4ad1c4514d9a847 Mon Sep 17 00:00:00 2001 From: Dong Wang Date: Fri, 13 Mar 2026 01:21:36 +0000 Subject: [PATCH 3/7] checkout can't be blocking --- matrix/agents/agent_actor.py | 33 +++++++++++++++++++++++++++++---- matrix/agents/dispatcher.py | 23 +++++++++++++++-------- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/matrix/agents/agent_actor.py b/matrix/agents/agent_actor.py index 9e39704..96370e2 100644 --- a/matrix/agents/agent_actor.py +++ b/matrix/agents/agent_actor.py @@ -17,6 +17,7 @@ from ray.util.metrics import Counter, Gauge from .agent_utils import remote_call_with_retry, setup_logging +from .dispatcher import Dispatcher from .orchestrator import BaseResourceClient, DeadOrchestrator, Orchestrator logger = logging.getLogger(__name__) @@ -271,13 +272,37 @@ def _log_exceptions(task): if self.dispatcher_name is not None: # Pull-based mode: resolve dispatcher by name, notify of (re)start - self.dispatcher = ray.get_actor(self.dispatcher_name, namespace=self.namespace) - await self.dispatcher.agent_started.remote(self.ray_name) + try: + self.dispatcher = ray.get_actor(self.dispatcher_name, namespace=self.namespace) + except Exception as e: + self.logger.error( + f"Agent {self.id} failed to find dispatcher {self.dispatcher_name}: {repr(e)}" + ) + return + try: + await self.dispatcher.agent_started.remote(self.ray_name) + except Exception as e: + self.logger.error( + f"Agent {self.id} failed to call agent_started on dispatcher: {repr(e)}" + ) + return while self.running: - orchestrator = await self.dispatcher.checkout.remote(self.ray_name) - if orchestrator is None: # Shutdown sentinel + try: + result = await self.dispatcher.checkout.remote(self.ray_name) + except Exception as e: + self.logger.error( + f"Agent {self.id} checkout failed: {repr(e)}, retrying..." + ) + await asyncio.sleep(1.0) + continue + if result == Dispatcher.SHUTDOWN_SENTINEL: break + if result is None: + # Queue empty, poll again after a short sleep + await asyncio.sleep(0.1) + continue + orchestrator = result latency = time.time() - orchestrator.enqueue_timestamp self.dequeue_latency.set(latency) # type: ignore[attr-defined] if self.ENABLE_INSTRUMENTATION: diff --git a/matrix/agents/dispatcher.py b/matrix/agents/dispatcher.py index da41076..096ceaa 100644 --- a/matrix/agents/dispatcher.py +++ b/matrix/agents/dispatcher.py @@ -31,6 +31,9 @@ class Dispatcher: any previously checked-out orchestrators for that agent. """ + # Sentinel object returned by checkout() to signal agents to stop + SHUTDOWN_SENTINEL = "SHUTDOWN" + def __init__(self, role: str, sink: ray.actor.ActorHandle, namespace: str): self.role = role self.sink = sink @@ -39,6 +42,7 @@ def __init__(self, role: str, sink: ray.actor.ActorHandle, namespace: str): self.incoming_queue: asyncio.Queue = asyncio.Queue() # orch_id -> (agent_ray_name, orchestrator) self.checked_out: Dict[str, tuple[str, Orchestrator]] = {} + self._shutting_down: bool = False # Other role Dispatchers for forwarding (handles are stable — dispatchers don't restart) self.dispatchers: Dict[str, ray.actor.ActorHandle] = {} @@ -66,13 +70,16 @@ async def enqueue(self, orchestrator: Orchestrator): async def checkout(self, agent_ray_name: str) -> Optional[Orchestrator]: """ - Agent pulls work, blocks until available. - Returns None as shutdown sentinel. + Agent pulls work. Returns immediately with an orchestrator or None + if the queue is empty. Agents poll with sleep on their side. + Returns SHUTDOWN_SENTINEL when the dispatcher is shutting down. """ - orchestrator = await self.incoming_queue.get() - self.queue_size.set(self.incoming_queue.qsize()) - if orchestrator is None: + if self._shutting_down: + return self.SHUTDOWN_SENTINEL + if self.incoming_queue.empty(): return None + orchestrator = self.incoming_queue.get_nowait() + self.queue_size.set(self.incoming_queue.qsize()) self.checked_out[orchestrator.id] = (agent_ray_name, orchestrator) return orchestrator @@ -149,9 +156,9 @@ async def _send_to_sink(self, orchestrator: Orchestrator): ) async def shutdown(self): - """Put None sentinels in queue (one per known agent) to stop their loops.""" - for _ in self.known_agents: - await self.incoming_queue.put(None) + """Signal all agents to stop by setting the shutdown flag. + Agents will receive SHUTDOWN_SENTINEL on their next checkout call.""" + self._shutting_down = True async def check_health(self): return True From 985c7151110129138a64e606c09cd59dd111d600 Mon Sep 17 00:00:00 2001 From: Dong Wang Date: Mon, 16 Mar 2026 23:11:56 +0000 Subject: [PATCH 4/7] change to dispatch push --- matrix/agents/agent_actor.py | 98 ++++++++------------------- matrix/agents/dispatcher.py | 127 +++++++++++++++++++++++++---------- 2 files changed, 120 insertions(+), 105 deletions(-) diff --git a/matrix/agents/agent_actor.py b/matrix/agents/agent_actor.py index 96370e2..d68eff5 100644 --- a/matrix/agents/agent_actor.py +++ b/matrix/agents/agent_actor.py @@ -17,7 +17,6 @@ from ray.util.metrics import Counter, Gauge from .agent_utils import remote_call_with_retry, setup_logging -from .dispatcher import Dispatcher from .orchestrator import BaseResourceClient, DeadOrchestrator, Orchestrator logger = logging.getLogger(__name__) @@ -145,18 +144,14 @@ def __init__( "num of kb for the serialized message object", {}, ), + ( + "queue_size", + Gauge, + "agent_queue_size", + "Current queue size for this agent", + {}, + ), ] - # queue_size only makes sense for Sink (local queue); dispatched agents use Dispatcher's metric - if self.dispatcher_name is None: - metrics_config.append( - ( - "queue_size", - Gauge, - "agent_queue_size", - "Current queue size for this agent", - {}, - ), - ) self._init_metrics(metrics_config) @staticmethod @@ -270,8 +265,8 @@ def _log_exceptions(task): self.tasks_completed.inc() # type: ignore[attr-defined] self.pending_tasks_count.set(len(self.pending_tasks)) # type: ignore[attr-defined] + # Resolve dispatcher for submit routing (if this agent has one) if self.dispatcher_name is not None: - # Pull-based mode: resolve dispatcher by name, notify of (re)start try: self.dispatcher = ray.get_actor(self.dispatcher_name, namespace=self.namespace) except Exception as e: @@ -280,72 +275,37 @@ def _log_exceptions(task): ) return try: - await self.dispatcher.agent_started.remote(self.ray_name) + self_handle = ray.get_runtime_context().current_actor + await self.dispatcher.agent_started.remote(self.ray_name, self_handle) except Exception as e: self.logger.error( f"Agent {self.id} failed to call agent_started on dispatcher: {repr(e)}" ) return - while self.running: - try: - result = await self.dispatcher.checkout.remote(self.ray_name) - except Exception as e: - self.logger.error( - f"Agent {self.id} checkout failed: {repr(e)}, retrying..." - ) - await asyncio.sleep(1.0) - continue - if result == Dispatcher.SHUTDOWN_SENTINEL: - break - if result is None: - # Queue empty, poll again after a short sleep - await asyncio.sleep(0.1) - continue - orchestrator = result - latency = time.time() - orchestrator.enqueue_timestamp - self.dequeue_latency.set(latency) # type: ignore[attr-defined] - if self.ENABLE_INSTRUMENTATION: - orchestrator.append_instrumentation( - self.dequeue_latency, self.agent_id, latency # type: ignore[attr-defined] - ) - - task = asyncio.create_task(self._handle(orchestrator)) - task._orchestrator = orchestrator # type: ignore[attr-defined] - - self.pending_tasks.add(task) - self.tasks_started.inc() # type: ignore[attr-defined] - self.pending_tasks_count.set(len(self.pending_tasks)) # type: ignore[attr-defined] - - task.add_done_callback(self.pending_tasks.discard) - task.add_done_callback(_log_exceptions) - else: - # Local-queue mode (Sink) - while self.running: - orchestrator = await self.queue.get() - if orchestrator is None: # Shutdown sentinel - break - latency = time.time() - orchestrator.enqueue_timestamp - self.dequeue_latency.set(latency) # type: ignore[attr-defined] - if self.ENABLE_INSTRUMENTATION: - orchestrator.append_instrumentation( - self.dequeue_latency, self.agent_id, latency # type: ignore[attr-defined] - ) + # All agents: read from local queue (dispatcher pushes here, or direct for Sink) + while self.running: + orchestrator = await self.queue.get() + if orchestrator is None: # Shutdown sentinel + break + latency = time.time() - orchestrator.enqueue_timestamp + self.dequeue_latency.set(latency) # type: ignore[attr-defined] + if self.ENABLE_INSTRUMENTATION: + orchestrator.append_instrumentation( + self.dequeue_latency, self.agent_id, latency # type: ignore[attr-defined] + ) - # Update queue size after getting message - self.queue_size.set(self.queue.qsize()) # type: ignore[attr-defined] + self.queue_size.set(self.queue.qsize()) # type: ignore[attr-defined] - task = asyncio.create_task(self._handle(orchestrator)) - # Attach orchestrator to task for error logging - task._orchestrator = orchestrator # type: ignore[attr-defined] + task = asyncio.create_task(self._handle(orchestrator)) + task._orchestrator = orchestrator # type: ignore[attr-defined] - self.pending_tasks.add(task) - self.tasks_started.inc() # type: ignore[attr-defined] - self.pending_tasks_count.set(len(self.pending_tasks)) # type: ignore[attr-defined] + self.pending_tasks.add(task) + self.tasks_started.inc() # type: ignore[attr-defined] + self.pending_tasks_count.set(len(self.pending_tasks)) # type: ignore[attr-defined] - # Clean up completed tasks - task.add_done_callback(self.pending_tasks.discard) - task.add_done_callback(_log_exceptions) + task.add_done_callback(self.pending_tasks.discard) + task.add_done_callback(_log_exceptions) if self.pending_tasks: await asyncio.gather(*self.pending_tasks, return_exceptions=True) diff --git a/matrix/agents/dispatcher.py b/matrix/agents/dispatcher.py index 096ceaa..c22e12f 100644 --- a/matrix/agents/dispatcher.py +++ b/matrix/agents/dispatcher.py @@ -22,18 +22,15 @@ class Dispatcher: """ Dispatcher actor (one per role) that acts as a message broker. - Agents pull work from their Dispatcher via checkout(), process it, - and submit results back via submit(). The Dispatcher tracks exactly - which agent has which orchestrator. + Receives orchestrators via enqueue(), pushes them to agents via + receive_message() round-robin, and routes completed work to the next + Dispatcher or Sink. Dead-agent detection is handled by agent restarts: when an agent restarts (max_restarts=-1), it calls agent_started() which tombstones any previously checked-out orchestrators for that agent. """ - # Sentinel object returned by checkout() to signal agents to stop - SHUTDOWN_SENTINEL = "SHUTDOWN" - def __init__(self, role: str, sink: ray.actor.ActorHandle, namespace: str): self.role = role self.sink = sink @@ -42,12 +39,13 @@ def __init__(self, role: str, sink: ray.actor.ActorHandle, namespace: str): self.incoming_queue: asyncio.Queue = asyncio.Queue() # orch_id -> (agent_ray_name, orchestrator) self.checked_out: Dict[str, tuple[str, Orchestrator]] = {} - self._shutting_down: bool = False # Other role Dispatchers for forwarding (handles are stable — dispatchers don't restart) self.dispatchers: Dict[str, ray.actor.ActorHandle] = {} - # Known agents (ray actor names), populated by agent_started() - self.known_agents: set[str] = set() + # Agent handles for push-based delivery, populated by set_agents() + self.agents: Dict[str, ray.actor.ActorHandle] = {} + self._agent_names: list[str] = [] + self._next_idx: int = 0 self.logger = logging.getLogger(f"Dispatcher[{role}]") @@ -58,6 +56,10 @@ def __init__(self, role: str, sink: ray.actor.ActorHandle, namespace: str): ) self.queue_size.set_default_tags({"role": role}) + # Event loop waits for first agent to register via agent_started() + self._agents_ready = asyncio.Event() + self._loop_task = asyncio.get_event_loop().create_task(self._event_loop()) + async def set_dispatchers(self, dispatchers: Dict[str, ray.actor.ActorHandle]): """Wire this Dispatcher to other role Dispatchers for forwarding.""" self.dispatchers = dispatchers @@ -68,20 +70,54 @@ async def enqueue(self, orchestrator: Orchestrator): await self.incoming_queue.put(orchestrator) self.queue_size.set(self.incoming_queue.qsize()) - async def checkout(self, agent_ray_name: str) -> Optional[Orchestrator]: - """ - Agent pulls work. Returns immediately with an orchestrator or None - if the queue is empty. Agents poll with sleep on their side. - Returns SHUTDOWN_SENTINEL when the dispatcher is shutting down. - """ - if self._shutting_down: - return self.SHUTDOWN_SENTINEL - if self.incoming_queue.empty(): - return None - orchestrator = self.incoming_queue.get_nowait() - self.queue_size.set(self.incoming_queue.qsize()) - self.checked_out[orchestrator.id] = (agent_ray_name, orchestrator) - return orchestrator + async def _event_loop(self): + """Dequeue orchestrators and push to agents round-robin.""" + await self._agents_ready.wait() + while True: + orchestrator = await self.incoming_queue.get() + if orchestrator is None: # shutdown sentinel + break + self.queue_size.set(self.incoming_queue.qsize()) + await self._push_to_agent(orchestrator) + + async def _push_to_agent(self, orchestrator: Orchestrator): + """Push orchestrator to an agent with round-robin and retry.""" + if not self._agent_names: + self.logger.error(f"No agents registered for role {self.role}") + dead = DeadOrchestrator( + orchestrator.id, error=f"No agents for role {self.role}" + ) + await self._send_to_sink(dead) + return + + num_agents = len(self._agent_names) + for attempt in range(num_agents): + idx = (self._next_idx + attempt) % num_agents + agent_name = self._agent_names[idx] + agent_handle = self.agents[agent_name] + try: + await remote_call_with_retry( + agent_handle.receive_message, + orchestrator, + logger=self.logger, + ) + self._next_idx = (idx + 1) % num_agents + self.checked_out[orchestrator.id] = (agent_name, orchestrator) + return + except Exception as e: + self.logger.warning( + f"Failed to push to agent {agent_name}: {repr(e)}" + ) + continue + + # All agents failed + self.logger.error( + f"All agents unreachable for orch {orchestrator.id}, tombstoning" + ) + dead = DeadOrchestrator( + orchestrator.id, error=f"All agents for {self.role} unreachable" + ) + await self._send_to_sink(dead) async def submit(self, processed_orch: Orchestrator, orch_id: str): """ @@ -123,19 +159,30 @@ async def submit_error(self, orchestrator: Orchestrator, orch_id: str): self.checked_out.pop(orch_id, None) await self._send_to_sink(orchestrator) - async def agent_started(self, agent_ray_name: str): + async def agent_started(self, agent_ray_name: str, agent_handle: ray.actor.ActorHandle): """ - Called on agent (re)start. Registers the agent and tombstones any - previously checked-out orchestrators for it (handles restart race). + Called on agent (re)start. Updates the agent handle and tombstones + any previously checked-out orchestrators for this agent. """ - self.known_agents.add(agent_ray_name) - to_tombstone = [ - (oid, orch) - for oid, (aname, orch) in self.checked_out.items() - if aname == agent_ray_name - ] + # Register / update agent handle + self.agents[agent_ray_name] = agent_handle + if agent_ray_name not in self._agent_names: + self._agent_names.append(agent_ray_name) + if not self._agents_ready.is_set(): + self._agents_ready.set() + + # Atomically remove all checked-out entries for this agent (no await + # between removals, so _push_to_agent cannot interleave here) + to_tombstone = [] + for oid in list(self.checked_out): + aname, orch = self.checked_out[oid] + if aname == agent_ray_name: + del self.checked_out[oid] + to_tombstone.append((oid, orch)) + + # Now send tombstones — _push_to_agent may interleave at these await + # points, but any new checked_out entries are for the new agent instance for oid, orch in to_tombstone: - del self.checked_out[oid] self.logger.warning( f"Agent {agent_ray_name} restarted, tombstoning orch {oid}" ) @@ -156,9 +203,17 @@ async def _send_to_sink(self, orchestrator: Orchestrator): ) async def shutdown(self): - """Signal all agents to stop by setting the shutdown flag. - Agents will receive SHUTDOWN_SENTINEL on their next checkout call.""" - self._shutting_down = True + """Stop event loop and signal all agents to shut down.""" + # Stop the event loop + await self.incoming_queue.put(None) + # Signal each agent to shut down + for agent_name, agent_handle in self.agents.items(): + try: + await agent_handle.shutdown.remote() + except Exception as e: + self.logger.warning( + f"Error shutting down agent {agent_name}: {repr(e)}" + ) async def check_health(self): return True From c7e7f3d947cef868191b9af58575f48139d94b7f Mon Sep 17 00:00:00 2001 From: Dong Wang Date: Tue, 17 Mar 2026 02:07:06 +0000 Subject: [PATCH 5/7] fix bug that agents got gc --- matrix/agents/agent_actor.py | 14 ++--- matrix/agents/agent_utils.py | 93 ++++++---------------------- matrix/agents/config/simulation.yaml | 10 +-- matrix/agents/dispatcher.py | 19 +++--- matrix/agents/p2p_agents.py | 17 +++-- 5 files changed, 53 insertions(+), 100 deletions(-) diff --git a/matrix/agents/agent_actor.py b/matrix/agents/agent_actor.py index d68eff5..a73db4b 100644 --- a/matrix/agents/agent_actor.py +++ b/matrix/agents/agent_actor.py @@ -16,10 +16,9 @@ from omegaconf import DictConfig from ray.util.metrics import Counter, Gauge -from .agent_utils import remote_call_with_retry, setup_logging +from .agent_utils import send_with_retry, setup_logging from .orchestrator import BaseResourceClient, DeadOrchestrator, Orchestrator -logger = logging.getLogger(__name__) # ==== Abstract AgentActor ==== @@ -46,7 +45,6 @@ def __init__( self.config = config self.resource_name = config.get("resource_name") if self.resource_name: - logger.debug(f"Resources {list(resources.keys())}") self.resource_client: Optional[BaseResourceClient] = resources[ self.resource_name ] @@ -227,7 +225,7 @@ async def _handle_task_exception(orchestrator, msg): ) if self.dispatcher is not None: try: - await remote_call_with_retry( + await send_with_retry( self.dispatcher.submit_error, orchestrator, orchestrator.id, @@ -239,7 +237,7 @@ async def _handle_task_exception(orchestrator, msg): ) # Fall back to sending directly to sink try: - await remote_call_with_retry( + await send_with_retry( self.sink.receive_message, orchestrator, logger=self.logger ) except Exception: @@ -275,7 +273,7 @@ def _log_exceptions(task): ) return try: - self_handle = ray.get_runtime_context().current_actor + self_handle = ray.get_runtime_context().current_actor await self.dispatcher.agent_started.remote(self.ray_name, self_handle) except Exception as e: self.logger.error( @@ -336,7 +334,7 @@ async def _handle(self, orchestrator: Orchestrator): if self.dispatcher is not None: # Submit to dispatcher for deterministic routing try: - await remote_call_with_retry( + await send_with_retry( self.dispatcher.submit, next_state, orchestrator.id, @@ -351,7 +349,7 @@ async def _handle(self, orchestrator: Orchestrator): error=f"Dispatcher unreachable: {e}", ) try: - await remote_call_with_retry( + await send_with_retry( self.sink.receive_message, dead, logger=self.logger ) except Exception: diff --git a/matrix/agents/agent_utils.py b/matrix/agents/agent_utils.py index 40bb321..b77551a 100644 --- a/matrix/agents/agent_utils.py +++ b/matrix/agents/agent_utils.py @@ -9,7 +9,6 @@ import json import logging import os -import random import re import time from collections import namedtuple @@ -182,111 +181,57 @@ class HistPair(NamedTuple): REMOTE_CALL_TIMEOUT = 60.0 # seconds per attempt REMOTE_CALL_RETRIES = 3 -REMOTE_CALL_BACKOFF = 2.0 # seconds between retries -async def remote_call_with_retry( +async def send_with_retry( method, *args, - retries: int = REMOTE_CALL_RETRIES, timeout: float = REMOTE_CALL_TIMEOUT, - backoff: float = REMOTE_CALL_BACKOFF, + max_retries: int = REMOTE_CALL_RETRIES, logger: Optional[logging.Logger] = None, ): """Call a Ray remote method with retry and timeout. + Only retries on RayActorError (actor dead/restarting) and TimeoutError. + Other exceptions are raised immediately. + Args: method: Ray actor method (e.g. handle.submit) *args: Arguments to pass to method.remote(*args) - retries: Number of attempts timeout: Timeout per attempt in seconds - backoff: Seconds to wait between retries + max_retries: Maximum retry attempts logger: Logger for warnings Returns: The result of the remote call - Raises: - The last exception if all retries exhausted - """ - last_error = None - for attempt in range(retries): - try: - return await asyncio.wait_for(method.remote(*args), timeout=timeout) - except Exception as e: - last_error = e - if logger: - logger.warning( - f"Remote call failed (attempt {attempt + 1}/{retries}): {repr(e)}" - ) - if attempt < retries - 1: - await asyncio.sleep(backoff) - raise last_error # type: ignore[misc] - - -# ==== Utility Functions ==== -async def send_with_retry( - orchestrator: "Orchestrator", - role: str, - sink: ray.actor.ActorHandle, - local_cache: Dict[str, List[ray.actor.ActorHandle]], - log: logging.Logger, - timeout: float = 60.0, - max_retries: int = 3, -) -> Dict[str, List[ray.actor.ActorHandle]]: - """ - Send orchestrator to an agent with local cache and fault-tolerant retry. - - Args: - orchestrator: The orchestrator state to send - role: The role name of the target agent - sink: The sink actor handle for registry lookups - local_cache: Local team cache dict (will be updated on refresh) - log: Logger instance for warnings - timeout: Timeout for actor acquisition - max_retries: Maximum retry attempts - - Returns: - Updated local_cache dict - Raises: RuntimeError: If all retries exhausted + Exception: Non-retryable exceptions are raised immediately """ last_exception = None for attempt in range(max_retries): try: - if role == "_sink": - agent = sink - elif attempt == 0 and role in local_cache and local_cache[role]: - # First attempt: use local cache for speed - agent = random.choice(local_cache[role]) - else: - # Fallback: get from sink with force refresh and update local cache - agent = await sink.get_actor.remote(role, timeout, True) - local_cache = await sink.get_team_snapshot.remote() - - await agent.receive_message.remote(orchestrator) - return local_cache # Success + return await asyncio.wait_for(method.remote(*args), timeout=timeout) except ray.exceptions.RayActorError as e: last_exception = e - log.warning( - f"Actor {role} is dead (attempt {attempt + 1}/{max_retries}): {repr(e)}" - ) - # Clear local cache for this role to force refresh - local_cache.pop(role, None) + if logger: + logger.warning( + f"Actor dead (attempt {attempt + 1}/{max_retries}): {repr(e)}" + ) continue except TimeoutError as e: - last_exception = e # type: ignore[assignment] - log.warning( - f"Timeout getting actor {role} (attempt {attempt + 1}/{max_retries}): {repr(e)}" - ) + last_exception = e + if logger: + logger.warning( + f"Timeout (attempt {attempt + 1}/{max_retries}): {repr(e)}" + ) continue except Exception: - # For other exceptions, don't retry + # Non-retryable, raise immediately raise - # All retries exhausted raise RuntimeError( - f"Failed to send to {role} after {max_retries} attempts: {last_exception}" + f"Remote call failed after {max_retries} attempts: {last_exception}" ) diff --git a/matrix/agents/config/simulation.yaml b/matrix/agents/config/simulation.yaml index a8dbc28..1cdad0b 100644 --- a/matrix/agents/config/simulation.yaml +++ b/matrix/agents/config/simulation.yaml @@ -7,15 +7,17 @@ matrix: max_concurrent_tasks: 100 rate_limit_enqueue: false -# Dispatcher settings -dispatcher: - ray_resources: - num_cpus: 1 debug: false seed: 42 num_trial: 1 +# Dispatcher settings +dispatcher: + debug: ${debug} + ray_resources: + num_cpus: 1 + # Default output settings output: path: "simulation_results.jsonl" diff --git a/matrix/agents/dispatcher.py b/matrix/agents/dispatcher.py index c22e12f..44e0d6d 100644 --- a/matrix/agents/dispatcher.py +++ b/matrix/agents/dispatcher.py @@ -12,12 +12,9 @@ import ray from ray.util.metrics import Gauge -from .agent_utils import remote_call_with_retry +from .agent_utils import send_with_retry, setup_logging from .orchestrator import DeadOrchestrator, Orchestrator -logger = logging.getLogger(__name__) - - class Dispatcher: """ Dispatcher actor (one per role) that acts as a message broker. @@ -31,7 +28,7 @@ class Dispatcher: any previously checked-out orchestrators for that agent. """ - def __init__(self, role: str, sink: ray.actor.ActorHandle, namespace: str): + def __init__(self, role: str, sink: ray.actor.ActorHandle, namespace: str, debug: bool = False): self.role = role self.sink = sink self.namespace = namespace @@ -48,6 +45,7 @@ def __init__(self, role: str, sink: ray.actor.ActorHandle, namespace: str): self._next_idx: int = 0 self.logger = logging.getLogger(f"Dispatcher[{role}]") + setup_logging(self.logger, debug) self.queue_size = Gauge( "dispatcher_queue_size", @@ -73,6 +71,7 @@ async def enqueue(self, orchestrator: Orchestrator): async def _event_loop(self): """Dequeue orchestrators and push to agents round-robin.""" await self._agents_ready.wait() + self.logger.debug("{self.role} run event loop") while True: orchestrator = await self.incoming_queue.get() if orchestrator is None: # shutdown sentinel @@ -96,7 +95,7 @@ async def _push_to_agent(self, orchestrator: Orchestrator): agent_name = self._agent_names[idx] agent_handle = self.agents[agent_name] try: - await remote_call_with_retry( + await send_with_retry( agent_handle.receive_message, orchestrator, logger=self.logger, @@ -106,7 +105,8 @@ async def _push_to_agent(self, orchestrator: Orchestrator): return except Exception as e: self.logger.warning( - f"Failed to push to agent {agent_name}: {repr(e)}" + f"Failed to push to agent {agent_name}: {type(e).__name__}: {e}", + exc_info=True, ) continue @@ -134,7 +134,7 @@ async def submit(self, processed_orch: Orchestrator, orch_id: str): await self._send_to_sink(processed_orch) elif next_role in self.dispatchers: try: - await remote_call_with_retry( + await send_with_retry( self.dispatchers[next_role].enqueue, processed_orch, logger=self.logger, @@ -165,6 +165,7 @@ async def agent_started(self, agent_ray_name: str, agent_handle: ray.actor.Actor any previously checked-out orchestrators for this agent. """ # Register / update agent handle + self.logger.debug(f"Agent_started {agent_ray_name} {agent_handle}") self.agents[agent_ray_name] = agent_handle if agent_ray_name not in self._agent_names: self._agent_names.append(agent_ray_name) @@ -194,7 +195,7 @@ async def agent_started(self, agent_ray_name: str, agent_handle: ray.actor.Actor async def _send_to_sink(self, orchestrator: Orchestrator): """Send to sink with retry. If sink is unreachable, log and drop.""" try: - await remote_call_with_retry( + await send_with_retry( self.sink.receive_message, orchestrator, logger=self.logger ) except Exception as e: diff --git a/matrix/agents/p2p_agents.py b/matrix/agents/p2p_agents.py index c7c259b..a363538 100644 --- a/matrix/agents/p2p_agents.py +++ b/matrix/agents/p2p_agents.py @@ -97,12 +97,17 @@ def __init__(self, simulation_id: str, dispatcher_config: Optional[DictConfig] = self.sink: Optional[ray.actor.ActorHandle] = None # Dispatcher per role (excluding sink) self.dispatchers: Dict[str, ray.actor.ActorHandle] = {} + # Agent handles per role (must be kept alive to prevent GC) + self.agents: Dict[str, List[ray.actor.ActorHandle]] = {} # Dispatcher ray resource config self.dispatcher_ray_resources: dict[str, Any] = {} - if dispatcher_config and "ray_resources" in dispatcher_config: - self.dispatcher_ray_resources = OmegaConf.to_container( - dispatcher_config["ray_resources"], resolve=True - ) + self.dispatcher_debug: bool = False + if dispatcher_config: + if "ray_resources" in dispatcher_config: + self.dispatcher_ray_resources = OmegaConf.to_container( + dispatcher_config["ray_resources"], resolve=True + ) + self.dispatcher_debug = dispatcher_config.get("debug", False) def create_role(self, role_name: str, agent_config: DictConfig, resources): """Create agents for a role. _sink must be created first.""" @@ -132,7 +137,7 @@ def create_role(self, role_name: str, agent_config: DictConfig, resources): namespace=self.simulation_id, max_restarts=0, **self.dispatcher_ray_resources, - ).remote(role=role_name, sink=self.sink, namespace=self.simulation_id) + ).remote(role=role_name, debug=self.dispatcher_debug, sink=self.sink, namespace=self.simulation_id) self.dispatchers[role_name] = dispatcher logger.info(f"Created dispatcher for role: {role_name}") @@ -166,6 +171,8 @@ def create_role(self, role_name: str, agent_config: DictConfig, resources): if is_sink: self.sink = agents[0] + self.agents[role_name] = agents + self.teamConfig[role_name] = ( agent_class.__ray_metadata__.modified_class, agent_config, From 3e0a8245302cc9b0e4b44fa34a3c0dfccd09a86d Mon Sep 17 00:00:00 2001 From: Dong Wang Date: Tue, 17 Mar 2026 16:27:39 +0000 Subject: [PATCH 6/7] sleep when actor is down --- matrix/agents/agent_utils.py | 1 + matrix/agents/examples/tau2_bench.py | 4 ++++ matrix/scripts/kill_ray_actor.py | 20 ++++++++++++++++++-- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/matrix/agents/agent_utils.py b/matrix/agents/agent_utils.py index b77551a..5615e8a 100644 --- a/matrix/agents/agent_utils.py +++ b/matrix/agents/agent_utils.py @@ -220,6 +220,7 @@ async def send_with_retry( logger.warning( f"Actor dead (attempt {attempt + 1}/{max_retries}): {repr(e)}" ) + await asyncio.sleep(10) continue except TimeoutError as e: last_exception = e diff --git a/matrix/agents/examples/tau2_bench.py b/matrix/agents/examples/tau2_bench.py index 907e5b8..4fd3383 100644 --- a/matrix/agents/examples/tau2_bench.py +++ b/matrix/agents/examples/tau2_bench.py @@ -181,6 +181,7 @@ def __init__( config: DictConfig, resources: dict[str, BaseResourceClient], sink=None, + **kwargs, ): super().__init__( id, @@ -188,6 +189,7 @@ def __init__( config, resources=resources, sink=sink, + **kwargs, ) if agent_id == "llm_agent": # instantiate the template @@ -357,6 +359,7 @@ def __init__( config: DictConfig, resources: dict[str, BaseResourceClient], sink=None, + **kwargs, ): super().__init__( id, @@ -364,6 +367,7 @@ def __init__( config, resources=resources, sink=sink, + **kwargs, ) self.tmp_dir = os.path.abspath(os.path.expanduser(config["tmp_dir"])) diff --git a/matrix/scripts/kill_ray_actor.py b/matrix/scripts/kill_ray_actor.py index 7b7554f..796fa98 100644 --- a/matrix/scripts/kill_ray_actor.py +++ b/matrix/scripts/kill_ray_actor.py @@ -153,9 +153,25 @@ def kill_random_repeatedly( print() kill_count = 0 + # Pre-determine the kill order: shuffle all actors, then cycle through + if max_kills is not None: + kill_order = actor_names * ((max_kills // len(actor_names)) + 1) + random.shuffle(kill_order) + kill_order = kill_order[:max_kills] + else: + kill_order = None + + idx = 0 while max_kills is None or kill_count < max_kills: - # Pick a random actor - target = random.choice(actor_names) + if kill_order is not None: + target = kill_order[idx] + else: + # For unlimited mode, shuffle a full round then repeat + if idx % len(actor_names) == 0: + current_round = list(actor_names) + random.shuffle(current_round) + target = current_round[idx % len(actor_names)] + idx += 1 print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] Attempting to kill: {target}") # Kill remotely From 264201ce256c9ca6ba0fb1dbd51b1c37789a351b Mon Sep 17 00:00:00 2001 From: Dong Wang Date: Wed, 18 Mar 2026 03:26:39 +0000 Subject: [PATCH 7/7] ignore a non-existent submit --- matrix/agents/agent_actor.py | 14 +++++++++----- matrix/agents/agent_utils.py | 2 +- matrix/agents/dispatcher.py | 27 ++++++++++++++++++++++----- matrix/agents/p2p_agents.py | 29 +++++++++++++++++++++-------- matrix/agents/sink.py | 29 +++++++++-------------------- 5 files changed, 62 insertions(+), 39 deletions(-) diff --git a/matrix/agents/agent_actor.py b/matrix/agents/agent_actor.py index a73db4b..7756d02 100644 --- a/matrix/agents/agent_actor.py +++ b/matrix/agents/agent_actor.py @@ -20,7 +20,6 @@ from .orchestrator import BaseResourceClient, DeadOrchestrator, Orchestrator - # ==== Abstract AgentActor ==== # @ray.remote class AgentActor(abc.ABC): @@ -69,7 +68,9 @@ def __init__( # Dispatcher name + namespace for name-based resolution (None for Sink) self.dispatcher_name = dispatcher_name self.namespace = namespace - self.ray_name = ray_name # This agent's Ray actor name (for dispatcher identification) + self.ray_name = ( + ray_name # This agent's Ray actor name (for dispatcher identification) + ) self.dispatcher = None # resolved lazily in _event_loop self.event_loop_task: Optional[asyncio.Task] = ( @@ -245,7 +246,7 @@ async def _handle_task_exception(orchestrator, msg): f"Failed to send error orch {orchestrator.id} to sink, dropping" ) else: - await self.sink.receive_message.remote(orchestrator) + await self.receive_message(orchestrator) def _log_exceptions(task): try: @@ -266,14 +267,17 @@ def _log_exceptions(task): # Resolve dispatcher for submit routing (if this agent has one) if self.dispatcher_name is not None: try: - self.dispatcher = ray.get_actor(self.dispatcher_name, namespace=self.namespace) + self.dispatcher = ray.get_actor( + self.dispatcher_name, namespace=self.namespace + ) + assert self.dispatcher is not None except Exception as e: self.logger.error( f"Agent {self.id} failed to find dispatcher {self.dispatcher_name}: {repr(e)}" ) return try: - self_handle = ray.get_runtime_context().current_actor + self_handle = ray.get_runtime_context().current_actor await self.dispatcher.agent_started.remote(self.ray_name, self_handle) except Exception as e: self.logger.error( diff --git a/matrix/agents/agent_utils.py b/matrix/agents/agent_utils.py index 5615e8a..9568423 100644 --- a/matrix/agents/agent_utils.py +++ b/matrix/agents/agent_utils.py @@ -209,7 +209,7 @@ async def send_with_retry( RuntimeError: If all retries exhausted Exception: Non-retryable exceptions are raised immediately """ - last_exception = None + last_exception: Optional[Exception] = None for attempt in range(max_retries): try: diff --git a/matrix/agents/dispatcher.py b/matrix/agents/dispatcher.py index 44e0d6d..e462067 100644 --- a/matrix/agents/dispatcher.py +++ b/matrix/agents/dispatcher.py @@ -15,6 +15,7 @@ from .agent_utils import send_with_retry, setup_logging from .orchestrator import DeadOrchestrator, Orchestrator + class Dispatcher: """ Dispatcher actor (one per role) that acts as a message broker. @@ -28,7 +29,13 @@ class Dispatcher: any previously checked-out orchestrators for that agent. """ - def __init__(self, role: str, sink: ray.actor.ActorHandle, namespace: str, debug: bool = False): + def __init__( + self, + role: str, + sink: ray.actor.ActorHandle, + namespace: str, + debug: bool = False, + ): self.role = role self.sink = sink self.namespace = namespace @@ -95,13 +102,13 @@ async def _push_to_agent(self, orchestrator: Orchestrator): agent_name = self._agent_names[idx] agent_handle = self.agents[agent_name] try: + self.checked_out[orchestrator.id] = (agent_name, orchestrator) await send_with_retry( agent_handle.receive_message, orchestrator, logger=self.logger, ) self._next_idx = (idx + 1) % num_agents - self.checked_out[orchestrator.id] = (agent_name, orchestrator) return except Exception as e: self.logger.warning( @@ -114,6 +121,7 @@ async def _push_to_agent(self, orchestrator: Orchestrator): self.logger.error( f"All agents unreachable for orch {orchestrator.id}, tombstoning" ) + self.checked_out.pop(orchestrator.id) dead = DeadOrchestrator( orchestrator.id, error=f"All agents for {self.role} unreachable" ) @@ -124,7 +132,14 @@ async def submit(self, processed_orch: Orchestrator, orch_id: str): Agent acks completion. Dispatcher removes from checked_out, checks is_done()/current_agent(), forwards to target Dispatcher or Sink. """ - self.checked_out.pop(orch_id, None) + if orch_id not in self.checked_out: + # race condition, dispatch already marked it died + self.logger.warning( + f"Dispatcher {self.role} failed to find {orch_id}, ignore" + ) + return + + self.checked_out.pop(orch_id) if await processed_orch.is_done(): await self._send_to_sink(processed_orch) @@ -156,10 +171,12 @@ async def submit(self, processed_orch: Orchestrator, orch_id: str): async def submit_error(self, orchestrator: Orchestrator, orch_id: str): """Agent error ack. Forward directly to Sink.""" - self.checked_out.pop(orch_id, None) + self.checked_out.pop(orch_id) await self._send_to_sink(orchestrator) - async def agent_started(self, agent_ray_name: str, agent_handle: ray.actor.ActorHandle): + async def agent_started( + self, agent_ray_name: str, agent_handle: ray.actor.ActorHandle + ): """ Called on agent (re)start. Updates the agent handle and tombstones any previously checked-out orchestrators for this agent. diff --git a/matrix/agents/p2p_agents.py b/matrix/agents/p2p_agents.py index a363538..f277805 100644 --- a/matrix/agents/p2p_agents.py +++ b/matrix/agents/p2p_agents.py @@ -36,13 +36,13 @@ setup_logging, ) from .dataset_loader import BaseDatasetLoader +from .dispatcher import Dispatcher from .orchestrator import ( BaseResourceClient, DeadOrchestrator, Orchestrator, SequentialOrchestrator, ) -from .dispatcher import Dispatcher from .sink import Sink # Re-export all public names for backward compatibility @@ -90,7 +90,9 @@ def done(self): class ScalableTeamManager: """Manages teams with multiple actors per role using Dispatchers for routing and load balancing""" - def __init__(self, simulation_id: str, dispatcher_config: Optional[DictConfig] = None): + def __init__( + self, simulation_id: str, dispatcher_config: Optional[DictConfig] = None + ): self.simulation_id = simulation_id self.teamConfig: Dict[str, Tuple[Type, DictConfig]] = {} # Sink actor handle - must be created first @@ -104,7 +106,7 @@ def __init__(self, simulation_id: str, dispatcher_config: Optional[DictConfig] = self.dispatcher_debug: bool = False if dispatcher_config: if "ray_resources" in dispatcher_config: - self.dispatcher_ray_resources = OmegaConf.to_container( + self.dispatcher_ray_resources = OmegaConf.to_container( # type: ignore [assignment] dispatcher_config["ray_resources"], resolve=True ) self.dispatcher_debug = dispatcher_config.get("debug", False) @@ -137,8 +139,13 @@ def create_role(self, role_name: str, agent_config: DictConfig, resources): namespace=self.simulation_id, max_restarts=0, **self.dispatcher_ray_resources, - ).remote(role=role_name, debug=self.dispatcher_debug, sink=self.sink, namespace=self.simulation_id) - self.dispatchers[role_name] = dispatcher + ).remote( + role=role_name, + debug=self.dispatcher_debug, + sink=self.sink, + namespace=self.simulation_id, + ) + self.dispatchers[role_name] = dispatcher # type: ignore[assignment] logger.info(f"Created dispatcher for role: {role_name}") agents = [] @@ -192,7 +199,9 @@ async def initialize_team(self, team: Dict[str, List[ray.actor.ActorHandle]]): all_actors.extend(role_handles) all_actors.extend(self.dispatchers.values()) - logger.info(f"Checking Ray actor health for {len(all_actors)} actors (including dispatchers)") + logger.info( + f"Checking Ray actor health for {len(all_actors)} actors (including dispatchers)" + ) try: await asyncio.wait_for( asyncio.gather( @@ -381,9 +390,13 @@ async def _process_item(self, trial_item: Tuple[int, Dict[str, Any]]): # Enqueue to the first agent's Dispatcher try: - await self.team_manager.dispatchers[first_agent_role].enqueue.remote(orchestrator) + await self.team_manager.dispatchers[first_agent_role].enqueue.remote( + orchestrator + ) except Exception as e: - logger.error(f"Failed to enqueue to dispatcher for {first_agent_role}: {repr(e)}") + logger.error( + f"Failed to enqueue to dispatcher for {first_agent_role}: {repr(e)}" + ) orchestrator.status["error"] = f"Failed to reach {first_agent_role}: {e}" await self.sink.receive_message.remote(orchestrator) # type: ignore[attr-defined] return diff --git a/matrix/agents/sink.py b/matrix/agents/sink.py index bdeaeb2..14beedd 100644 --- a/matrix/agents/sink.py +++ b/matrix/agents/sink.py @@ -46,9 +46,6 @@ def __init__( self.num_done = 0 self.num_inputs: Optional[int] = None self.ray_objects: dict[str, ray.ObjectRef] = {} # hold the ref to avoid gc - self.pending_writes: int = ( - 0 # Track in-progress writes to avoid closing file prematurely - ) self.num_dead: int = 0 # Counter for dead/lost orchestrators additional_metrics_config: list[tuple[str, type, str, str, dict[str, Any]]] = [ @@ -121,18 +118,14 @@ def _write_output(output_data, output_path): # Run CPU-intensive work in thread pool start_time = time.perf_counter() loop = asyncio.get_event_loop() - self.pending_writes += 1 # Track pending write before yielding - try: - data_to_write = await loop.run_in_executor( - None, - partial( - _write_output, await orchestrator.to_output(), self.output_path - ), - ) - self.output_file.write(data_to_write) - self.sink_write_latency.set(time.perf_counter() - start_time) # type: ignore[attr-defined] - finally: - self.pending_writes -= 1 # Always decrement, even on error + data_to_write = await loop.run_in_executor( + None, + partial( + _write_output, await orchestrator.to_output(), self.output_path + ), + ) + self.output_file.write(data_to_write) + self.sink_write_latency.set(time.perf_counter() - start_time) # type: ignore[attr-defined] # Increment num_done for ALL arrivals (normal + dead) self.num_done += 1 @@ -148,11 +141,7 @@ def _write_output(output_data, output_path): self.metrics_accumulator.accumulate(orchestrator) # Close output file when all tasks are done and no pending writes - if ( - self.num_inputs is not None - and self.num_done >= self.num_inputs - and self.pending_writes == 0 - ): + if self.num_inputs is not None and self.num_done >= self.num_inputs: self.output_file.close() return {"orchestrator": orchestrator}