Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion daisy/task_worker_pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,14 @@ def stop(self):
worker_pool.stop()

# Periodic health sweep over all task worker pools.
# NOTE: this hunk is a diff — the first `for` line below is the REMOVED
# version and the second is its replacement (iteration switched from
# .values() to .items() so task_id is available for the log message).
def check_worker_health(self):
for worker_pool in self.worker_pools.values():
for task_id, worker_pool in self.worker_pools.items():
# Surface any errors workers queued since the last sweep.
worker_pool.check_for_errors()
# Remove workers whose OS processes died silently (no error queued).
reaped = worker_pool.reap_dead_workers()
if reaped > 0:
logger.warning(
"Replacing %d dead workers for task %s", reaped, task_id
)
# Ask the pool for the same number of fresh workers to replace the
# reaped ones (inc_num_workers is defined on WorkerPool — presumably
# it spawns the new processes; confirm in daisy/worker_pool.py).
worker_pool.inc_num_workers(reaped)

def on_block_failure(self, block, exception, context):
task_id = context["task_id"]
Expand Down
22 changes: 22 additions & 0 deletions daisy/worker_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,28 @@ def check_for_errors(self):
except queue.Empty:
pass

def reap_dead_workers(self):
"""Detect worker processes that have exited and remove them from the
pool. Returns the number of workers reaped.

A worker that dies without queueing an error (e.g. SIGKILL/OOM, or a
bare SystemExit) would otherwise linger in ``self.workers`` forever;
this sweep is what lets the server notice and replace it.
"""

dead_worker_ids = []
with self.workers_lock:
# Pass 1: record ids of workers whose process has terminated. We only
# collect here — deleting from ``self.workers`` while iterating it
# would raise RuntimeError.
for worker_id, worker in self.workers.items():
if worker.process is not None and not worker.process.is_alive():
logger.warning(
"Worker %s (pid %d) exited with code %d",
worker,
worker.process.pid,
worker.process.exitcode,
)
# Drop the process handle so nothing later tries to join or
# terminate an already-dead process.
worker.process = None
dead_worker_ids.append(worker_id)

# Pass 2: forget the dead workers.
# NOTE(review): the diff rendering strips indentation, so it is unclear
# whether this loop runs inside ``workers_lock`` — confirm in the file.
for worker_id in dead_worker_ids:
del self.workers[worker_id]

return len(dead_worker_ids)

def _start_workers(self, n):
logger.debug("starting %d new workers", n)
new_workers = [
Expand Down
29 changes: 29 additions & 0 deletions tests/process_block_or_die.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Worker script that crashes on first invocation, works on subsequent ones.

Uses a marker file to track whether a crash has already occurred. The first
worker to run creates the marker and exits via SystemExit (which bypasses
the normal exception handling in daisy's Worker._spawn_wrapper). Subsequent
workers see the marker and process blocks normally.
"""

import daisy

import os
import sys

tmp_path = sys.argv[1]
marker = os.path.join(tmp_path, "worker_crashed")

if not os.path.exists(marker):
# First worker: create marker and crash
with open(marker, "w") as f:
f.write("crashed")
raise SystemExit(1)

# Subsequent workers: process blocks normally
client = daisy.Client()

while True:
with client.acquire_block() as block:
if block is None:
break
55 changes: 55 additions & 0 deletions tests/test_dead_workers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Test that the server detects and replaces dead worker processes.

Workers can die silently (e.g., SIGKILL/OOM, SystemExit) without queuing an
error. Without dead worker detection, the server would hang forever waiting
for messages from workers that no longer exist.
"""

import daisy
from daisy.logging import set_log_basedir

import logging
import os
import subprocess
import sys

logging.basicConfig(level=logging.DEBUG)


def test_dead_worker_replacement(tmp_path):
    """Workers that exit via SystemExit are detected and replaced.

    The first worker raises SystemExit (standing in for an OOM kill or a
    similarly unrecoverable crash that bypasses normal exception handling
    and queues no error). The server's dead-worker detection must notice
    the loss, start a replacement, and the replacement worker must finish
    the task.
    """
    set_log_basedir(tmp_path)

    def spawn_worker():
        # Deliberately no check=True: the first invocation is expected to
        # exit non-zero.
        subprocess.run(
            [sys.executable, "tests/process_block_or_die.py", str(tmp_path)]
        )

    # A single one-dimensional block covering the whole ROI.
    total = daisy.Roi((0,), (10,))
    read = daisy.Roi((0,), (10,))
    write = daisy.Roi((0,), (10,))

    task = daisy.Task(
        "test_dead_worker_task",
        total_roi=total,
        read_roi=read,
        write_roi=write,
        process_function=spawn_worker,
        check_function=None,
        read_write_conflict=False,
        fit="valid",
        num_workers=1,
        max_retries=2,
        timeout=None,
    )

    results = daisy.Server().run_blockwise([task])
    assert results[task.task_id].is_done(), results[task.task_id]

    # The marker file proves the first worker really did crash.
    assert os.path.exists(tmp_path / "worker_crashed"), (
        "Expected first worker to crash and leave a marker file"
    )
Loading