diff --git a/matrix/agents/agent_actor.py b/matrix/agents/agent_actor.py index 5cca1d1..7756d02 100644 --- a/matrix/agents/agent_actor.py +++ b/matrix/agents/agent_actor.py @@ -17,9 +17,7 @@ from ray.util.metrics import Counter, Gauge from .agent_utils import send_with_retry, setup_logging -from .orchestrator import BaseResourceClient, Orchestrator - -logger = logging.getLogger(__name__) +from .orchestrator import BaseResourceClient, DeadOrchestrator, Orchestrator # ==== Abstract AgentActor ==== @@ -34,6 +32,9 @@ def __init__( config: DictConfig, resources: dict[str, BaseResourceClient], sink: ray.actor.ActorHandle, + dispatcher_name: str = None, + namespace: str = None, + ray_name: str = None, ): # PATCH FIRST - before any HTTP clients are created self._patch_getproxies() @@ -43,7 +44,6 @@ def __init__( self.config = config self.resource_name = config.get("resource_name") if self.resource_name: - logger.debug(f"Resources {list(resources.keys())}") self.resource_client: Optional[BaseResourceClient] = resources[ self.resource_name ] @@ -65,8 +65,13 @@ def __init__( # Store sink reference and start event loop # For Sink actor itself, sink will be None (set later via _set_self_as_sink) self.sink = sink - # Local team cache for fast actor lookup, updated from sink on failure - self._local_team_cache: Dict[str, List[ray.actor.ActorHandle]] = {} + # Dispatcher name + namespace for name-based resolution (None for Sink) + self.dispatcher_name = dispatcher_name + self.namespace = namespace + self.ray_name = ( + ray_name # This agent's Ray actor name (for dispatcher identification) + ) + self.dispatcher = None # resolved lazily in _event_loop self.event_loop_task: Optional[asyncio.Task] = ( asyncio.get_event_loop().create_task(self._event_loop()) @@ -81,13 +86,6 @@ def __init__( "Total number of messages processed by this agent", {}, ), - ( - "queue_size", - Gauge, - "agent_queue_size", - "Current queue size for this agent", - {}, - ), ( "messages_received", Counter, @@ 
-145,6 +143,13 @@ def __init__( "num of kb for the serialized message object", {}, ), + ( + "queue_size", + Gauge, + "agent_queue_size", + "Current queue size for this agent", + {}, + ), ] self._init_metrics(metrics_config) @@ -219,7 +224,29 @@ async def _handle_task_exception(orchestrator, msg): orchestrator._append( self.agent_id, {"status_ok": False, "error": msg}, self.sink ) - await self.sink.receive_message.remote(orchestrator) + if self.dispatcher is not None: + try: + await send_with_retry( + self.dispatcher.submit_error, + orchestrator, + orchestrator.id, + logger=self.logger, + ) + except Exception as e: + self.logger.error( + f"Failed to submit error to dispatcher for {orchestrator.id}: {repr(e)}" + ) + # Fall back to sending directly to sink + try: + await send_with_retry( + self.sink.receive_message, orchestrator, logger=self.logger + ) + except Exception: + self.logger.error( + f"Failed to send error orch {orchestrator.id} to sink, dropping" + ) + else: + await self.receive_message(orchestrator) def _log_exceptions(task): try: @@ -237,6 +264,28 @@ def _log_exceptions(task): self.tasks_completed.inc() # type: ignore[attr-defined] self.pending_tasks_count.set(len(self.pending_tasks)) # type: ignore[attr-defined] + # Resolve dispatcher for submit routing (if this agent has one) + if self.dispatcher_name is not None: + try: + self.dispatcher = ray.get_actor( + self.dispatcher_name, namespace=self.namespace + ) + assert self.dispatcher is not None + except Exception as e: + self.logger.error( + f"Agent {self.id} failed to find dispatcher {self.dispatcher_name}: {repr(e)}" + ) + return + try: + self_handle = ray.get_runtime_context().current_actor + await self.dispatcher.agent_started.remote(self.ray_name, self_handle) + except Exception as e: + self.logger.error( + f"Agent {self.id} failed to call agent_started on dispatcher: {repr(e)}" + ) + return + + # All agents: read from local queue (dispatcher pushes here, or direct for Sink) while self.running: 
orchestrator = await self.queue.get() if orchestrator is None: # Shutdown sentinel @@ -248,18 +297,15 @@ def _log_exceptions(task): self.dequeue_latency, self.agent_id, latency # type: ignore[attr-defined] ) - # Update queue size after getting message self.queue_size.set(self.queue.qsize()) # type: ignore[attr-defined] task = asyncio.create_task(self._handle(orchestrator)) - # Attach orchestrator to task for error logging task._orchestrator = orchestrator # type: ignore[attr-defined] self.pending_tasks.add(task) self.tasks_started.inc() # type: ignore[attr-defined] self.pending_tasks_count.set(len(self.pending_tasks)) # type: ignore[attr-defined] - # Clean up completed tasks task.add_done_callback(self.pending_tasks.discard) task.add_done_callback(_log_exceptions) @@ -275,10 +321,6 @@ async def _handle(self, orchestrator: Orchestrator): result = await self.postprocess(orchestrator, result) if self.agent_id != "_sink": next_state = await orchestrator.update(result, self, self.logger) - if await next_state.is_done(): - next_agent_name = "_sink" - else: - next_agent_name = next_state.current_agent() # temporary if self.ENABLE_INSTRUMENTATION: @@ -293,14 +335,31 @@ async def _handle(self, orchestrator: Orchestrator): self.handle_latency, self.agent_id, latency # type: ignore[attr-defined] ) - # Send to next agent with fault-tolerant retry - self._local_team_cache = await send_with_retry( - next_state, - next_agent_name, - self.sink, - self._local_team_cache, - self.logger, - ) + if self.dispatcher is not None: + # Submit to dispatcher for deterministic routing + try: + await send_with_retry( + self.dispatcher.submit, + next_state, + orchestrator.id, + logger=self.logger, + ) + except Exception as e: + self.logger.error( + f"Failed to submit orch {orchestrator.id} to dispatcher: {repr(e)}" + ) + dead = DeadOrchestrator( + orchestrator.id, + error=f"Dispatcher unreachable: {e}", + ) + try: + await send_with_retry( + self.sink.receive_message, dead, logger=self.logger + ) 
REMOTE_CALL_TIMEOUT = 60.0  # seconds per attempt
REMOTE_CALL_RETRIES = 3
RETRY_BACKOFF_SECONDS = 10.0  # pause before retrying after an actor death


async def send_with_retry(
    method,
    *args,
    timeout: float = REMOTE_CALL_TIMEOUT,
    max_retries: int = REMOTE_CALL_RETRIES,
    logger: Optional[logging.Logger] = None,
):
    """Call a Ray remote method with a per-attempt timeout and bounded retries.

    Only RayActorError (actor dead/restarting) and timeouts are retried.
    Other exceptions are raised immediately.

    Args:
        method: Bound Ray actor method (e.g. handle.submit)
        *args: Arguments to pass to method.remote(*args)
        timeout: Timeout per attempt in seconds
        max_retries: Maximum retry attempts
        logger: Logger for warnings

    Returns:
        The result of the remote call

    Raises:
        RuntimeError: If all retries exhausted (chained to the last
            retryable exception)
        Exception: Non-retryable exceptions are raised immediately
    """
    last_exception: Optional[Exception] = None
    for attempt in range(max_retries):
        try:
            return await asyncio.wait_for(method.remote(*args), timeout=timeout)
        # asyncio.wait_for raises asyncio.TimeoutError, which is only an
        # alias of the builtin TimeoutError on Python >= 3.11 -- catch both
        # so the timeout path actually retries on older interpreters too.
        except (TimeoutError, asyncio.TimeoutError) as e:
            last_exception = e
            if logger:
                logger.warning(
                    f"Timeout (attempt {attempt + 1}/{max_retries}): {repr(e)}"
                )
            continue
        except ray.exceptions.RayActorError as e:
            last_exception = e
            if logger:
                logger.warning(
                    f"Actor dead (attempt {attempt + 1}/{max_retries}): {repr(e)}"
                )
            # Give a restarting actor (max_restarts=-1) time to come back.
            await asyncio.sleep(RETRY_BACKOFF_SECONDS)
            continue
        except Exception:
            # Non-retryable, raise immediately
            raise

    raise RuntimeError(
        f"Remote call failed after {max_retries} attempts: {last_exception}"
    ) from last_exception
class Dispatcher:
    """
    Dispatcher actor (one per role) that acts as a message broker.

    Receives orchestrators via enqueue(), pushes them to agents via
    receive_message() round-robin, and routes completed work to the next
    Dispatcher or Sink.

    Dead-agent detection is handled by agent restarts: when an agent
    restarts (max_restarts=-1), it calls agent_started() which tombstones
    any previously checked-out orchestrators for that agent.
    """

    def __init__(
        self,
        role: str,
        sink: ray.actor.ActorHandle,
        namespace: str,
        debug: bool = False,
    ):
        """
        Args:
            role: The agent role this dispatcher routes for.
            sink: Sink actor handle (final destination and tombstone target).
            namespace: Ray namespace shared by the simulation's actors.
            debug: Enable debug-level logging.
        """
        self.role = role
        self.sink = sink
        self.namespace = namespace

        self.incoming_queue: asyncio.Queue = asyncio.Queue()
        # orch_id -> (agent_ray_name, orchestrator)
        self.checked_out: Dict[str, tuple[str, Orchestrator]] = {}

        # Other role Dispatchers for forwarding (handles are stable -- dispatchers don't restart)
        self.dispatchers: Dict[str, ray.actor.ActorHandle] = {}
        # Agent handles for push-based delivery, populated by agent_started()
        self.agents: Dict[str, ray.actor.ActorHandle] = {}
        self._agent_names: list[str] = []
        self._next_idx: int = 0

        self.logger = logging.getLogger(f"Dispatcher[{role}]")
        setup_logging(self.logger, debug)

        self.queue_size = Gauge(
            "dispatcher_queue_size",
            description="Current queue size for this dispatcher",
            tag_keys=("role",),
        )
        self.queue_size.set_default_tags({"role": role})

        # Event loop waits for first agent to register via agent_started()
        self._agents_ready = asyncio.Event()
        self._loop_task = asyncio.get_event_loop().create_task(self._event_loop())

    async def set_dispatchers(self, dispatchers: Dict[str, ray.actor.ActorHandle]):
        """Wire this Dispatcher to other role Dispatchers for forwarding."""
        self.dispatchers = dispatchers

    async def enqueue(self, orchestrator: Orchestrator):
        """Called by other Dispatchers or the framework to queue work."""
        orchestrator.enqueue_timestamp = time.time()
        await self.incoming_queue.put(orchestrator)
        self.queue_size.set(self.incoming_queue.qsize())

    async def _event_loop(self):
        """Dequeue orchestrators and push to agents round-robin."""
        await self._agents_ready.wait()
        self.logger.debug(f"{self.role} run event loop")
        while True:
            orchestrator = await self.incoming_queue.get()
            if orchestrator is None:  # shutdown sentinel
                break
            self.queue_size.set(self.incoming_queue.qsize())
            await self._push_to_agent(orchestrator)

    async def _push_to_agent(self, orchestrator: Orchestrator):
        """Push orchestrator to an agent with round-robin and retry."""
        if not self._agent_names:
            self.logger.error(f"No agents registered for role {self.role}")
            dead = DeadOrchestrator(
                orchestrator.id, error=f"No agents for role {self.role}"
            )
            await self._send_to_sink(dead)
            return

        num_agents = len(self._agent_names)
        for attempt in range(num_agents):
            idx = (self._next_idx + attempt) % num_agents
            agent_name = self._agent_names[idx]
            agent_handle = self.agents[agent_name]
            try:
                # Record the checkout BEFORE awaiting delivery so that
                # agent_started() can tombstone it if the agent dies mid-send.
                self.checked_out[orchestrator.id] = (agent_name, orchestrator)
                await send_with_retry(
                    agent_handle.receive_message,
                    orchestrator,
                    logger=self.logger,
                )
                self._next_idx = (idx + 1) % num_agents
                return
            except Exception as e:
                self.logger.warning(
                    f"Failed to push to agent {agent_name}: {type(e).__name__}: {e}",
                    exc_info=True,
                )
                continue

        # All agents failed
        self.logger.error(
            f"All agents unreachable for orch {orchestrator.id}, tombstoning"
        )
        # pop with default: agent_started() may have tombstoned it already
        self.checked_out.pop(orchestrator.id, None)
        dead = DeadOrchestrator(
            orchestrator.id, error=f"All agents for {self.role} unreachable"
        )
        await self._send_to_sink(dead)

    async def submit(self, processed_orch: Orchestrator, orch_id: str):
        """
        Agent acks completion. Dispatcher removes from checked_out,
        checks is_done()/current_agent(), forwards to target Dispatcher or Sink.
        """
        if orch_id not in self.checked_out:
            # race condition, dispatch already marked it died
            self.logger.warning(
                f"Dispatcher {self.role} failed to find {orch_id}, ignore"
            )
            return

        self.checked_out.pop(orch_id)

        if await processed_orch.is_done():
            await self._send_to_sink(processed_orch)
        else:
            next_role = processed_orch.current_agent()
            if next_role == "_sink":
                await self._send_to_sink(processed_orch)
            elif next_role in self.dispatchers:
                try:
                    await send_with_retry(
                        self.dispatchers[next_role].enqueue,
                        processed_orch,
                        logger=self.logger,
                    )
                except Exception as e:
                    self.logger.error(
                        f"Failed to forward orch {orch_id} to dispatcher {next_role}: {repr(e)}"
                    )
                    dead = DeadOrchestrator(
                        orch_id,
                        error=f"Failed to forward to {next_role}: {e}",
                    )
                    await self._send_to_sink(dead)
            else:
                self.logger.error(
                    f"Unknown target role '{next_role}' for orch {orch_id}, sending to sink"
                )
                await self._send_to_sink(processed_orch)

    async def submit_error(self, orchestrator: Orchestrator, orch_id: str):
        """Agent error ack. Forward directly to Sink."""
        # pop with default: the entry may already be gone if agent_started()
        # tombstoned it after a restart -- the same race submit() guards
        # against; an unguarded pop would raise KeyError here.
        self.checked_out.pop(orch_id, None)
        await self._send_to_sink(orchestrator)

    async def agent_started(
        self, agent_ray_name: str, agent_handle: ray.actor.ActorHandle
    ):
        """
        Called on agent (re)start. Updates the agent handle and tombstones
        any previously checked-out orchestrators for this agent.
        """
        # Register / update agent handle
        self.logger.debug(f"Agent_started {agent_ray_name} {agent_handle}")
        self.agents[agent_ray_name] = agent_handle
        if agent_ray_name not in self._agent_names:
            self._agent_names.append(agent_ray_name)
        if not self._agents_ready.is_set():
            self._agents_ready.set()

        # Atomically remove all checked-out entries for this agent (no await
        # between removals, so _push_to_agent cannot interleave here)
        to_tombstone = []
        for oid in list(self.checked_out):
            aname, orch = self.checked_out[oid]
            if aname == agent_ray_name:
                del self.checked_out[oid]
                to_tombstone.append((oid, orch))

        # Now send tombstones -- _push_to_agent may interleave at these await
        # points, but any new checked_out entries are for the new agent instance
        for oid, orch in to_tombstone:
            self.logger.warning(
                f"Agent {agent_ray_name} restarted, tombstoning orch {oid}"
            )
            dead = DeadOrchestrator(
                oid, error=f"Agent {agent_ray_name} restarted while processing"
            )
            await self._send_to_sink(dead)

    async def _send_to_sink(self, orchestrator: Orchestrator):
        """Send to sink with retry. If sink is unreachable, log and drop."""
        try:
            await send_with_retry(
                self.sink.receive_message, orchestrator, logger=self.logger
            )
        except Exception as e:
            self.logger.error(
                f"Failed to send orch {orchestrator.id} to sink after retries: {repr(e)}"
            )

    async def shutdown(self):
        """Stop event loop and signal all agents to shut down."""
        # Stop the event loop
        await self.incoming_queue.put(None)
        # Signal each agent to shut down
        for agent_name, agent_handle in self.agents.items():
            try:
                await agent_handle.shutdown.remote()
            except Exception as e:
                self.logger.warning(
                    f"Error shutting down agent {agent_name}: {repr(e)}"
                )

    async def check_health(self):
        """Liveness probe for the team manager's startup health check."""
        return True
"Orchestrator", "SequentialOrchestrator", "DeadOrchestrator", @@ -60,6 +59,7 @@ "ContainerExecutionAgent", "LLMAgentActor", "Sink", + "Dispatcher", "ScalableTeamManager", "P2PAgentFramework", "main", @@ -88,15 +88,28 @@ def done(self): class ScalableTeamManager: - """Manages teams with multiple actors per role using load balancers when needed""" + """Manages teams with multiple actors per role using Dispatchers for routing and load balancing""" - def __init__(self, simulation_id: str): + def __init__( + self, simulation_id: str, dispatcher_config: Optional[DictConfig] = None + ): self.simulation_id = simulation_id self.teamConfig: Dict[str, Tuple[Type, DictConfig]] = {} - # Team registry config: role -> (count, namespace) for actor lookup - self.team_registry_config: Dict[str, Tuple[int, str]] = {} # Sink actor handle - must be created first self.sink: Optional[ray.actor.ActorHandle] = None + # Dispatcher per role (excluding sink) + self.dispatchers: Dict[str, ray.actor.ActorHandle] = {} + # Agent handles per role (must be kept alive to prevent GC) + self.agents: Dict[str, List[ray.actor.ActorHandle]] = {} + # Dispatcher ray resource config + self.dispatcher_ray_resources: dict[str, Any] = {} + self.dispatcher_debug: bool = False + if dispatcher_config: + if "ray_resources" in dispatcher_config: + self.dispatcher_ray_resources = OmegaConf.to_container( # type: ignore [assignment] + dispatcher_config["ray_resources"], resolve=True + ) + self.dispatcher_debug = dispatcher_config.get("debug", False) def create_role(self, role_name: str, agent_config: DictConfig, resources): """Create agents for a role. 
_sink must be created first.""" @@ -117,8 +130,27 @@ def create_role(self, role_name: str, agent_config: DictConfig, resources): # Sink should not restart; other actors restart infinitely max_restarts = 0 if is_sink else -1 + # Create Dispatcher for non-sink roles + dispatcher = None + if not is_sink: + DispatcherActor = ray.remote(Dispatcher) + dispatcher = DispatcherActor.options( + name=f"dispatcher_{role_name}", + namespace=self.simulation_id, + max_restarts=0, + **self.dispatcher_ray_resources, + ).remote( + role=role_name, + debug=self.dispatcher_debug, + sink=self.sink, + namespace=self.simulation_id, + ) + self.dispatchers[role_name] = dispatcher # type: ignore[assignment] + logger.info(f"Created dispatcher for role: {role_name}") + agents = [] for i in range(count): + ray_name = f"{role_name}_{i}" kwargs = { "id": f"{self.simulation_id}_{role_name}_{i}", "agent_id": role_name, @@ -127,9 +159,12 @@ def create_role(self, role_name: str, agent_config: DictConfig, resources): } if not is_sink: kwargs["sink"] = self.sink + kwargs["dispatcher_name"] = f"dispatcher_{role_name}" + kwargs["namespace"] = self.simulation_id + kwargs["ray_name"] = ray_name agent = agent_class.options( - name=f"{role_name}_{i}", + name=ray_name, namespace=self.simulation_id, max_restarts=max_restarts, **ray_resources, @@ -143,28 +178,30 @@ def create_role(self, role_name: str, agent_config: DictConfig, resources): if is_sink: self.sink = agents[0] + self.agents[role_name] = agents + self.teamConfig[role_name] = ( agent_class.__ray_metadata__.modified_class, agent_config, ) - # Only add non-sink roles to registry (sink won't restart) - if not is_sink: - self.team_registry_config[role_name] = (count, self.simulation_id) return agents async def initialize_team(self, team: Dict[str, List[ray.actor.ActorHandle]]): - """Initialize all agents with team references. + """Initialize all agents and wire Dispatchers. 
Args: team: Dict mapping role -> list of actor handles (from create_role return values) """ - # Use handles from create_role directly (avoid race with ray.get_actor) + # Health-check all actors and dispatchers all_actors = [self.sink] for role_handles in team.values(): all_actors.extend(role_handles) + all_actors.extend(self.dispatchers.values()) - logger.info(f"Checking Ray actor health for {len(all_actors)} actors") + logger.info( + f"Checking Ray actor health for {len(all_actors)} actors (including dispatchers)" + ) try: await asyncio.wait_for( asyncio.gather( @@ -179,18 +216,31 @@ async def initialize_team(self, team: Dict[str, List[ray.actor.ActorHandle]]): raise e logger.info("Checking Ray actor health done...") - # Initialize Sink's team registry with verified handles (avoid re-lookup race) - if self.sink is not None: - await self.sink.set_team_registry.remote(self.team_registry_config, team) + # Wire Dispatchers to each other + for role_name, dispatcher in self.dispatchers.items(): + await dispatcher.set_dispatchers.remote(self.dispatchers) + + logger.info(f"Dispatchers wired: {list(self.dispatchers.keys())}") def get_team_config(self): """Get team config dictionary for orchestrator routing""" return self.teamConfig async def shutdown(self): - """Shutdown all actors via sink's registry""" + """Shutdown Dispatchers (which send sentinels to agents), then Sink.""" + # Shutdown dispatchers first (sends sentinels to agents) + for role_name, dispatcher in self.dispatchers.items(): + try: + await dispatcher.shutdown.remote() + except Exception as e: + logger.warning(f"Error shutting down dispatcher {role_name}: {repr(e)}") + + # Then shutdown sink if self.sink: - await self.sink.shutdown_all.remote() + try: + await self.sink.shutdown.remote() + except Exception as e: + logger.warning(f"Error shutting down sink: {repr(e)}") class P2PAgentFramework: @@ -209,13 +259,12 @@ def __init__(self, sim_index: int, cfg: DictConfig): self.semaphore = 
asyncio.Semaphore(self.max_concurrent_tasks) self.sink: Sink = None # type: ignore[assignment] - self.team_manager = ScalableTeamManager(self.simulation_id) + self.team_manager = ScalableTeamManager( + self.simulation_id, + dispatcher_config=self.cfg.get("dispatcher"), + ) self.resources: Dict[str, BaseResourceClient] = {} - # Local team cache for latency-sensitive _process_item - # Updated from sink on actor failure - self._local_team_cache: Dict[str, List[ray.actor.ActorHandle]] = {} - random.seed(self.cfg["seed"]) self.num_trial = self.cfg["num_trial"] if self.num_trial > 1: @@ -256,9 +305,6 @@ async def create_team( await self.team_manager.initialize_team(team) self.sink = self.team_manager.sink # type: ignore[assignment] - # Initialize local team cache from sink - self._local_team_cache = await self.sink.get_team_snapshot.remote() # type: ignore[attr-defined] - async def _progress_task(self): async def _update_progress(): done = await self.sink.get_progress.remote() # type: ignore[attr-defined] @@ -342,22 +388,15 @@ async def _process_item(self, trial_item: Tuple[int, Dict[str, Any]]): logger.debug(f"done Init {orchestrator.id}") logger.debug(f"Enqueue: {orchestrator.id}") - # Register as in-flight before sending to first agent - if self.cfg.dead_orchestrator_tracking: - await self.sink.register_inflight.remote(orchestrator.id) # type: ignore[attr-defined] - - # Send to first agent with local cache for latency, fallback to sink on error + # Enqueue to the first agent's Dispatcher try: - self._local_team_cache = await send_with_retry( - orchestrator, - first_agent_role, - self.sink, # type: ignore[arg-type] - self._local_team_cache, - logger, + await self.team_manager.dispatchers[first_agent_role].enqueue.remote( + orchestrator + ) + except Exception as e: + logger.error( + f"Failed to enqueue to dispatcher for {first_agent_role}: {repr(e)}" ) - except RuntimeError as e: - # All retries exhausted - send to sink as error - logger.error(str(e)) 
orchestrator.status["error"] = f"Failed to reach {first_agent_role}: {e}" await self.sink.receive_message.remote(orchestrator) # type: ignore[attr-defined] return diff --git a/matrix/agents/sink.py b/matrix/agents/sink.py index d474345..14beedd 100644 --- a/matrix/agents/sink.py +++ b/matrix/agents/sink.py @@ -7,10 +7,9 @@ import asyncio import json import os -import random import time from functools import partial -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Optional import ray import zstandard as zstd @@ -33,15 +32,6 @@ # @ray.remote class Sink(AgentActor): - # Constants for team registry - REFRESH_INTERVAL = 30.0 # seconds between periodic refreshes - GET_ACTOR_TIMEOUT = 60.0 # default timeout for get_actor calls - GET_ACTOR_RETRY_INTERVAL = 1.0 # seconds between retries - - # dead detection - IDLE_TIMEOUT = 60.0 # seconds of idle time before checking for dead tasks - MAX_NEW_ZOMBIE = 10 # at most mark this number at once - LATE_ARRIVAL_INCR = 5 # make it harder to mark zombie for each late arrivals def __init__( self, @@ -56,30 +46,6 @@ def __init__( self.num_done = 0 self.num_inputs: Optional[int] = None self.ray_objects: dict[str, ray.ObjectRef] = {} # hold the ref to avoid gc - self.pending_writes: int = ( - 0 # Track in-progress writes to avoid closing file prematurely - ) - - # Team registry: Sink is the source of truth for actor handles - self._team: Dict[str, List[ray.actor.ActorHandle]] = {} - self._team_config: Dict[str, Tuple[int, str]] = {} # role -> (count, namespace) - self._team_lock = asyncio.Lock() - self._refresh_task: Optional[asyncio.Task] = None - - # Optimistic timeout tracking for dead orchestrator detection - self.dead_orchestrator_tracking = config.dead_orchestrator_tracking - self.dead_order_window = config.max_concurrent_tasks - # Registration order tracking: when task N completes, tasks before N-max_concurrent_tasks are zombies - self.registration_counter: int = 0 - 
self.inflight_orchestrators: dict[str, int] = {} # id -> registration_order - self.inflight_order: list[tuple[int, str]] = ( - [] - ) # sorted (order, id) for efficient oldest iteration, lazy update, have stale items - self.zombie_orchestrators: set[str] = set() # set of zombie orchestrator ids - - # Idle detection for dead task recovery - self.last_message_time: float = time.time() - self._idle_check_task: Optional[asyncio.Task] = None self.num_dead: int = 0 # Counter for dead/lost orchestrators additional_metrics_config: list[tuple[str, type, str, str, dict[str, Any]]] = [ @@ -130,12 +96,8 @@ async def set_metrics_output( async def set_num_inputs(self, num_inputs: int): self.num_inputs = num_inputs - # Start idle check task now that we know the total inputs - if self._idle_check_task is None: - self._idle_check_task = asyncio.create_task(self._idle_check_loop()) async def preprocess(self, orchestrator: "Orchestrator"): - # Update last message time for idle detection self.last_message_time = time.time() def _write_output(output_data, output_path): @@ -156,21 +118,17 @@ def _write_output(output_data, output_path): # Run CPU-intensive work in thread pool start_time = time.perf_counter() loop = asyncio.get_event_loop() - self.pending_writes += 1 # Track pending write before yielding - try: - data_to_write = await loop.run_in_executor( - None, - partial( - _write_output, await orchestrator.to_output(), self.output_path - ), - ) - self.output_file.write(data_to_write) - self.sink_write_latency.set(time.perf_counter() - start_time) # type: ignore[attr-defined] - finally: - self.pending_writes -= 1 # Always decrement, even on error + data_to_write = await loop.run_in_executor( + None, + partial( + _write_output, await orchestrator.to_output(), self.output_path + ), + ) + self.output_file.write(data_to_write) + self.sink_write_latency.set(time.perf_counter() - start_time) # type: ignore[attr-defined] - # Get registration order and remove from inflight tracking - 
completed_order = self.inflight_orchestrators.pop(orchestrator.id, None) + # Increment num_done for ALL arrivals (normal + dead) + self.num_done += 1 if is_tombstone: self.num_dead += 1 @@ -179,77 +137,18 @@ def _write_output(output_data, output_path): latency = orchestrator.init_timestamp - orchestrator.creation_timestamp self.task_init_latency.set(latency) # type: ignore[attr-defined] - # Check if this orchestrator was in zombie set (came back from the dead) - is_zombie_return = orchestrator.id in self.zombie_orchestrators - if is_zombie_return: - # Remove from zombie set - it actually completed - self.zombie_orchestrators.discard(orchestrator.id) - self.logger.info( - f"Zombie orchestrator {orchestrator.id} returned, removing from zombie set" - ) - self.dead_order_window += self.LATE_ARRIVAL_INCR - - if not is_zombie_return: - self.num_done += 1 - - # Position-based zombie detection: when a non-error orchestrator completes, - # all tasks registered more than max_concurrent_tasks before it should be done - if ( - self.dead_orchestrator_tracking - and not orchestrator.is_error() - and completed_order is not None - ): - threshold = completed_order - self.dead_order_window - if threshold > 0: - to_zombify = 0 - # Iterate through oldest entries first, using lazy deletion - # (skip entries already removed from inflight_orchestrators) - while ( - self.inflight_order - and self.inflight_order[0][0] <= threshold - and to_zombify < self.MAX_NEW_ZOMBIE - ): - order, orch_id = self.inflight_order.pop(0) - # Only zombify if still in inflight (not already completed) - if orch_id in self.inflight_orchestrators: - del self.inflight_orchestrators[orch_id] - self.zombie_orchestrators.add(orch_id) - to_zombify += 1 - # Increase num_done to release semaphore - self.num_done += 1 - - if to_zombify: - self.logger.info( - f"Moved {to_zombify} orchestrators to zombie (threshold order: {threshold})" - ) - if self.metrics_accumulator: self.metrics_accumulator.accumulate(orchestrator) - # 
Don't close output file if there are zombies or pending writes - if ( - self.num_inputs is not None - and self.num_done >= self.num_inputs - and not self.zombie_orchestrators - and self.pending_writes == 0 - ): + # Close output file when all tasks are done and no pending writes + if self.num_inputs is not None and self.num_done >= self.num_inputs: self.output_file.close() return {"orchestrator": orchestrator} async def get_progress(self) -> int: - # Cap num_done when there are zombies to prevent input pipeline from finishing early - if self.zombie_orchestrators and self.num_inputs is not None: - return min(self.num_done, self.num_inputs - 1) return self.num_done - async def register_inflight(self, orchestrator_id: str): - """Register an orchestrator as in-flight (before sending to first agent).""" - self.registration_counter += 1 - order = self.registration_counter - self.inflight_orchestrators[orchestrator_id] = order - self.inflight_order.append((order, orchestrator_id)) - async def get_overall_metrics(self) -> dict[str, Any] | None: return ( self.metrics_accumulator.done() @@ -265,35 +164,9 @@ async def check_health(self): return True async def shutdown(self): - """Gracefully shutdown the Sink agent and cancel background tasks.""" - if self._refresh_task is not None: - self._refresh_task.cancel() - try: - await self._refresh_task - except asyncio.CancelledError: - pass - self._refresh_task = None - if self._idle_check_task is not None: - self._idle_check_task.cancel() - try: - await self._idle_check_task - except asyncio.CancelledError: - pass - self._idle_check_task = None + """Gracefully shutdown the Sink agent.""" await super().shutdown() - async def shutdown_all(self): - """Shutdown all actors in the team registry, then shutdown self.""" - async with self._team_lock: - for role, actors in self._team.items(): - for actor in actors: - try: - await actor.shutdown.remote() - except Exception as e: - self.logger.warning(f"Error shutting down {role}: {repr(e)}") - 
# Finally shutdown self - await self.shutdown() - async def register_object(self, obj: list[ray.ObjectRef]): o = obj[0] self.ray_objects[o.hex()] = o # type: ignore[attr-defined] @@ -301,216 +174,3 @@ async def register_object(self, obj: list[ray.ObjectRef]): async def unregister_object(self, obj: list[ray.ObjectRef]): for o in obj: self.ray_objects.pop(o.hex(), None) # type: ignore[attr-defined] - - # ==== Team Registry Methods ==== - async def set_team_registry( - self, - team_config: Dict[str, Tuple[int, str]], - initial_team: Optional[Dict[str, List[ray.actor.ActorHandle]]] = None, - ): - """ - Initialize the team registry in Sink. - - Args: - team_config: Dict mapping role -> (count, namespace) for actor lookup - initial_team: Optional pre-verified actor handles to avoid initial lookup race - """ - async with self._team_lock: - self._team_config = team_config - self._team = initial_team if initial_team is not None else {} - - # Only do initial refresh if we didn't get pre-verified handles - if initial_team is None: - await self._refresh_team_internal() - - # Start periodic refresh task - if self._refresh_task is None: - self._refresh_task = asyncio.create_task(self._periodic_refresh()) - self.logger.info( - f"Team registry initialized with roles: {list(team_config.keys())}" - ) - - async def _periodic_refresh(self): - """Background task to periodically refresh actor handles.""" - while self.running: - try: - await asyncio.sleep(self.REFRESH_INTERVAL) - await self._refresh_team_internal() - except asyncio.CancelledError: - break - except Exception as e: - self.logger.warning(f"Error in periodic team refresh: {repr(e)}") - - async def _refresh_team_internal(self): - """Internal method to refresh all actor handles from Ray.""" - async with self._team_lock: - for role, (count, namespace) in self._team_config.items(): - new_handles = [] - for i in range(count): - actor_name = f"{role}_{i}" - try: - handle = ray.get_actor(actor_name, namespace=namespace) - 
new_handles.append(handle) - except ValueError: - # Actor not found (not yet restarted) - self.logger.warning( - f"Actor {actor_name} not found in namespace {namespace}" - ) - if new_handles: - self._team[role] = new_handles - self.logger.debug(f"Team refreshed: {list(self._team.keys())}") - - async def force_refresh(self): - """Force an immediate refresh of all actor handles.""" - await self._refresh_team_internal() - - async def _write_tombstone(self, orchestrator_id: str): - """ - Write a tombstone record for a lost orchestrator via normal Sink flow. - - Args: - orchestrator_id: The ID of the lost orchestrator - """ - self.logger.debug( - f"Creating tombstone for lost orchestrator: {orchestrator_id}" - ) - - # Create a DeadOrchestrator and process it through normal flow - dead_orch = DeadOrchestrator(orchestrator_id) - await self.receive_message(dead_orch) - - async def _idle_check_loop(self): - """Background task to detect idle state and write tombstones for lost tasks.""" - while self.running: - try: - await asyncio.sleep(10.0) # Check every 10 seconds - - # Skip if we're done and no zombies - has_zombies = bool(self.zombie_orchestrators) - if ( - self.num_inputs is not None - and self.num_done >= self.num_inputs - and not has_zombies - ): - break - - # Check if we've been idle long enough - idle_time = time.time() - self.last_message_time - if idle_time < self.IDLE_TIMEOUT: - continue - - # Check if all actors have empty queues and no pending tasks - all_idle = await self._check_all_actors_idle() - if not all_idle: - continue - - for orch_id in self.inflight_orchestrators: - self.zombie_orchestrators.add(orch_id) - self.num_done += 1 - self.inflight_orchestrators.clear() - - # All actors are idle - confirm zombies are dead and write tombstones - if self.zombie_orchestrators: - self.logger.warning( - f"System idle for {idle_time:.1f}s with {len(self.zombie_orchestrators)} " - f"zombie tasks. Writing tombstones to confirm dead." 
- ) - tasks = [ - self._write_tombstone(orch_id) - for orch_id in self.zombie_orchestrators - ] - self.zombie_orchestrators.clear() - await asyncio.gather(*tasks) - - # Wait for tombstones to be processed through the queue - while self.queue.qsize() > 0 or len(self.pending_tasks) > 0: - await asyncio.sleep(1) - # File will be closed by preprocess when last tombstone is written - break - - # Check if we're still short of expected tasks - if self.num_inputs is not None and self.num_done < self.num_inputs: - self.logger.error( - f"Completed {self.num_done} tasks but expected {self.num_inputs}. " - f"Some tasks may have been lost before registration." - ) - break - - except asyncio.CancelledError: - break - except Exception as e: - self.logger.warning(f"Error in idle check loop: {repr(e)}") - - async def _check_all_actors_idle(self) -> bool: - """Check if all actors have zero queue + pending tasks and idle timestamps.""" - now = time.time() - async with self._team_lock: - all_actors = [actor for actors in self._team.values() for actor in actors] - - for actor in all_actors: - try: - count, last_msg_time = await asyncio.wait_for( - actor.get_active_count.remote(), - timeout=5.0, - ) - if count > 0: - return False - # Also check if actor's last message time is within IDLE_TIMEOUT - if now - last_msg_time < self.IDLE_TIMEOUT: - return False - except Exception as e: - # If we can't reach an actor, it might be dead - don't consider system idle - self.logger.debug(f"Failed to get active count from actor: {repr(e)}") - return False - - return True - - async def get_actor( - self, role: str, timeout: Optional[float] = None, force_refresh: bool = False - ) -> ray.actor.ActorHandle: - """ - Get a random actor handle for the given role. - Blocks until an actor is available or timeout expires. 
- - Args: - role: The role name to get an actor for - timeout: Maximum time to wait in seconds (default: GET_ACTOR_TIMEOUT) - force_refresh: If True, refresh handles before returning - - Returns: - A random actor handle for the role - - Raises: - TimeoutError: If no actor is available within timeout - KeyError: If role is not registered - """ - if timeout is None: - timeout = self.GET_ACTOR_TIMEOUT - - if force_refresh: - await self._refresh_team_internal() - - start_time = time.time() - while True: - async with self._team_lock: - if role in self._team and self._team[role]: - return random.choice(self._team[role]) - - # Check timeout - elapsed = time.time() - start_time - if elapsed >= timeout: - raise TimeoutError( - f"Timeout waiting for actor with role '{role}' after {timeout}s" - ) - - # Wait and retry - self.logger.debug( - f"No actor available for role '{role}', retrying in {self.GET_ACTOR_RETRY_INTERVAL}s" - ) - await asyncio.sleep(self.GET_ACTOR_RETRY_INTERVAL) - await self._refresh_team_internal() - - async def get_team_snapshot(self) -> Dict[str, List[ray.actor.ActorHandle]]: - """Get a snapshot of the current team map.""" - async with self._team_lock: - return dict(self._team) diff --git a/matrix/scripts/kill_ray_actor.py b/matrix/scripts/kill_ray_actor.py index 7b7554f..796fa98 100644 --- a/matrix/scripts/kill_ray_actor.py +++ b/matrix/scripts/kill_ray_actor.py @@ -153,9 +153,25 @@ def kill_random_repeatedly( print() kill_count = 0 + # Pre-determine the kill order: shuffle all actors, then cycle through + if max_kills is not None: + kill_order = actor_names * ((max_kills // len(actor_names)) + 1) + random.shuffle(kill_order) + kill_order = kill_order[:max_kills] + else: + kill_order = None + + idx = 0 while max_kills is None or kill_count < max_kills: - # Pick a random actor - target = random.choice(actor_names) + if kill_order is not None: + target = kill_order[idx] + else: + # For unlimited mode, shuffle a full round then repeat + if idx % 
len(actor_names) == 0: + current_round = list(actor_names) + random.shuffle(current_round) + target = current_round[idx % len(actor_names)] + idx += 1 print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] Attempting to kill: {target}") # Kill remotely