Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 88 additions & 29 deletions matrix/agents/agent_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@
from ray.util.metrics import Counter, Gauge

from .agent_utils import send_with_retry, setup_logging
from .orchestrator import BaseResourceClient, Orchestrator

logger = logging.getLogger(__name__)
from .orchestrator import BaseResourceClient, DeadOrchestrator, Orchestrator


# ==== Abstract AgentActor ====
Expand All @@ -34,6 +32,9 @@ def __init__(
config: DictConfig,
resources: dict[str, BaseResourceClient],
sink: ray.actor.ActorHandle,
dispatcher_name: str = None,
namespace: str = None,
ray_name: str = None,
):
# PATCH FIRST - before any HTTP clients are created
self._patch_getproxies()
Expand All @@ -43,7 +44,6 @@ def __init__(
self.config = config
self.resource_name = config.get("resource_name")
if self.resource_name:
logger.debug(f"Resources {list(resources.keys())}")
self.resource_client: Optional[BaseResourceClient] = resources[
self.resource_name
]
Expand All @@ -65,8 +65,13 @@ def __init__(
# Store sink reference and start event loop
# For Sink actor itself, sink will be None (set later via _set_self_as_sink)
self.sink = sink
# Local team cache for fast actor lookup, updated from sink on failure
self._local_team_cache: Dict[str, List[ray.actor.ActorHandle]] = {}
# Dispatcher name + namespace for name-based resolution (None for Sink)
self.dispatcher_name = dispatcher_name
self.namespace = namespace
self.ray_name = (
ray_name # This agent's Ray actor name (for dispatcher identification)
)
self.dispatcher = None # resolved lazily in _event_loop

self.event_loop_task: Optional[asyncio.Task] = (
asyncio.get_event_loop().create_task(self._event_loop())
Expand All @@ -81,13 +86,6 @@ def __init__(
"Total number of messages processed by this agent",
{},
),
(
"queue_size",
Gauge,
"agent_queue_size",
"Current queue size for this agent",
{},
),
(
"messages_received",
Counter,
Expand Down Expand Up @@ -145,6 +143,13 @@ def __init__(
"num of kb for the serialized message object",
{},
),
(
"queue_size",
Gauge,
"agent_queue_size",
"Current queue size for this agent",
{},
),
]
self._init_metrics(metrics_config)

Expand Down Expand Up @@ -219,7 +224,29 @@ async def _handle_task_exception(orchestrator, msg):
orchestrator._append(
self.agent_id, {"status_ok": False, "error": msg}, self.sink
)
await self.sink.receive_message.remote(orchestrator)
if self.dispatcher is not None:
try:
await send_with_retry(
self.dispatcher.submit_error,
orchestrator,
orchestrator.id,
logger=self.logger,
)
except Exception as e:
self.logger.error(
f"Failed to submit error to dispatcher for {orchestrator.id}: {repr(e)}"
)
# Fall back to sending directly to sink
try:
await send_with_retry(
self.sink.receive_message, orchestrator, logger=self.logger
)
except Exception:
self.logger.error(
f"Failed to send error orch {orchestrator.id} to sink, dropping"
)
else:
await self.receive_message(orchestrator)

def _log_exceptions(task):
try:
Expand All @@ -237,6 +264,28 @@ def _log_exceptions(task):
self.tasks_completed.inc() # type: ignore[attr-defined]
self.pending_tasks_count.set(len(self.pending_tasks)) # type: ignore[attr-defined]

# Resolve dispatcher for submit routing (if this agent has one)
if self.dispatcher_name is not None:
try:
self.dispatcher = ray.get_actor(
self.dispatcher_name, namespace=self.namespace
)
assert self.dispatcher is not None
except Exception as e:
self.logger.error(
f"Agent {self.id} failed to find dispatcher {self.dispatcher_name}: {repr(e)}"
)
return
try:
self_handle = ray.get_runtime_context().current_actor
await self.dispatcher.agent_started.remote(self.ray_name, self_handle)
except Exception as e:
self.logger.error(
f"Agent {self.id} failed to call agent_started on dispatcher: {repr(e)}"
)
return

# All agents: read from local queue (dispatcher pushes here, or direct for Sink)
while self.running:
orchestrator = await self.queue.get()
if orchestrator is None: # Shutdown sentinel
Expand All @@ -248,18 +297,15 @@ def _log_exceptions(task):
self.dequeue_latency, self.agent_id, latency # type: ignore[attr-defined]
)

# Update queue size after getting message
self.queue_size.set(self.queue.qsize()) # type: ignore[attr-defined]

task = asyncio.create_task(self._handle(orchestrator))
# Attach orchestrator to task for error logging
task._orchestrator = orchestrator # type: ignore[attr-defined]

self.pending_tasks.add(task)
self.tasks_started.inc() # type: ignore[attr-defined]
self.pending_tasks_count.set(len(self.pending_tasks)) # type: ignore[attr-defined]

# Clean up completed tasks
task.add_done_callback(self.pending_tasks.discard)
task.add_done_callback(_log_exceptions)

Expand All @@ -275,10 +321,6 @@ async def _handle(self, orchestrator: Orchestrator):
result = await self.postprocess(orchestrator, result)
if self.agent_id != "_sink":
next_state = await orchestrator.update(result, self, self.logger)
if await next_state.is_done():
next_agent_name = "_sink"
else:
next_agent_name = next_state.current_agent()

# temporary
if self.ENABLE_INSTRUMENTATION:
Expand All @@ -293,14 +335,31 @@ async def _handle(self, orchestrator: Orchestrator):
self.handle_latency, self.agent_id, latency # type: ignore[attr-defined]
)

# Send to next agent with fault-tolerant retry
self._local_team_cache = await send_with_retry(
next_state,
next_agent_name,
self.sink,
self._local_team_cache,
self.logger,
)
if self.dispatcher is not None:
# Submit to dispatcher for deterministic routing
try:
await send_with_retry(
self.dispatcher.submit,
next_state,
orchestrator.id,
logger=self.logger,
)
except Exception as e:
self.logger.error(
f"Failed to submit orch {orchestrator.id} to dispatcher: {repr(e)}"
)
dead = DeadOrchestrator(
orchestrator.id,
error=f"Dispatcher unreachable: {e}",
)
try:
await send_with_retry(
self.sink.receive_message, dead, logger=self.logger
)
except Exception:
self.logger.error(
f"Failed to tombstone orch {orchestrator.id} to sink, dropping"
)
# Update last message time after successful send
self.last_message_time = time.time()
else:
Expand Down
78 changes: 34 additions & 44 deletions matrix/agents/agent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import json
import logging
import os
import random
import re
import time
from collections import namedtuple
Expand Down Expand Up @@ -180,69 +179,60 @@ class HistPair(NamedTuple):
response: RayDict


# ==== Utility Functions ====
REMOTE_CALL_TIMEOUT = 60.0 # seconds per attempt
REMOTE_CALL_RETRIES = 3


async def send_with_retry(
orchestrator: "Orchestrator",
role: str,
sink: ray.actor.ActorHandle,
local_cache: Dict[str, List[ray.actor.ActorHandle]],
log: logging.Logger,
timeout: float = 60.0,
max_retries: int = 3,
) -> Dict[str, List[ray.actor.ActorHandle]]:
"""
Send orchestrator to an agent with local cache and fault-tolerant retry.
method,
*args,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be confusing. Shall we explicitly define the params? It looks like it is most likely the orchestrator.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

submit, enqueue, and receive_message have different parameters.

timeout: float = REMOTE_CALL_TIMEOUT,
max_retries: int = REMOTE_CALL_RETRIES,
logger: Optional[logging.Logger] = None,
):
"""Call a Ray remote method with retry and timeout.

Only retries on RayActorError (actor dead/restarting) and TimeoutError.
Other exceptions are raised immediately.

Args:
orchestrator: The orchestrator state to send
role: The role name of the target agent
sink: The sink actor handle for registry lookups
local_cache: Local team cache dict (will be updated on refresh)
log: Logger instance for warnings
timeout: Timeout for actor acquisition
method: Ray actor method (e.g. handle.submit)
*args: Arguments to pass to method.remote(*args)
timeout: Timeout per attempt in seconds
max_retries: Maximum retry attempts
logger: Logger for warnings

Returns:
Updated local_cache dict
The result of the remote call

Raises:
RuntimeError: If all retries exhausted
Exception: Non-retryable exceptions are raised immediately
"""
last_exception = None
last_exception: Optional[Exception] = None

for attempt in range(max_retries):
try:
if role == "_sink":
agent = sink
elif attempt == 0 and role in local_cache and local_cache[role]:
# First attempt: use local cache for speed
agent = random.choice(local_cache[role])
else:
# Fallback: get from sink with force refresh and update local cache
agent = await sink.get_actor.remote(role, timeout, True)
local_cache = await sink.get_team_snapshot.remote()

await agent.receive_message.remote(orchestrator)
return local_cache # Success
return await asyncio.wait_for(method.remote(*args), timeout=timeout)
except ray.exceptions.RayActorError as e:
last_exception = e
log.warning(
f"Actor {role} is dead (attempt {attempt + 1}/{max_retries}): {repr(e)}"
)
# Clear local cache for this role to force refresh
local_cache.pop(role, None)
if logger:
logger.warning(
f"Actor dead (attempt {attempt + 1}/{max_retries}): {repr(e)}"
)
await asyncio.sleep(10)
continue
except TimeoutError as e:
last_exception = e # type: ignore[assignment]
log.warning(
f"Timeout getting actor {role} (attempt {attempt + 1}/{max_retries}): {repr(e)}"
)
last_exception = e
if logger:
logger.warning(
f"Timeout (attempt {attempt + 1}/{max_retries}): {repr(e)}"
)
continue
except Exception:
# For other exceptions, don't retry
# Non-retryable, raise immediately
raise

# All retries exhausted
raise RuntimeError(
f"Failed to send to {role} after {max_retries} attempts: {last_exception}"
f"Remote call failed after {max_retries} attempts: {last_exception}"
)
4 changes: 1 addition & 3 deletions matrix/agents/config/agents/sink.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
_sink:
_target_: matrix.agents.p2p_agents.Sink
debug: ${debug}
dead_orchestrator_tracking: ${dead_orchestrator_tracking}
max_concurrent_tasks: ${max_concurrent_tasks}


ray_resources:
num_cpus: 2
# resources:
Expand Down
10 changes: 8 additions & 2 deletions matrix/agents/config/simulation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@
matrix:
cluster_id: null
matrix_dir: null

# Simulation framework settings
max_concurrent_tasks: 100
rate_limit_enqueue: false


debug: false
seed: 42
num_trial: 1
dead_orchestrator_tracking: false

# Dispatcher settings
dispatcher:
debug: ${debug}
ray_resources:
num_cpus: 1

# Default output settings
output:
Expand Down
Loading
Loading