diff --git a/daisy/task_worker_pools.py b/daisy/task_worker_pools.py index 49e41fe8..e37b90a7 100644 --- a/daisy/task_worker_pools.py +++ b/daisy/task_worker_pools.py @@ -39,8 +39,14 @@ def stop(self): worker_pool.stop() def check_worker_health(self): - for worker_pool in self.worker_pools.values(): + for task_id, worker_pool in self.worker_pools.items(): worker_pool.check_for_errors() + reaped = worker_pool.reap_dead_workers() + if reaped > 0: + logger.warning( + "Replacing %d dead workers for task %s", reaped, task_id + ) + worker_pool.inc_num_workers(reaped) def on_block_failure(self, block, exception, context): task_id = context["task_id"] diff --git a/daisy/worker_pool.py b/daisy/worker_pool.py index e6078fd2..ee65b8ba 100644 --- a/daisy/worker_pool.py +++ b/daisy/worker_pool.py @@ -102,6 +102,28 @@ def check_for_errors(self): except queue.Empty: pass + def reap_dead_workers(self): + """Detect worker processes that have exited and remove them from the + pool. Returns the number of workers reaped.""" + + dead_worker_ids = [] + with self.workers_lock: + for worker_id, worker in self.workers.items(): + if worker.process is not None and not worker.process.is_alive(): + logger.warning( + "Worker %s (pid %d) exited with code %d", + worker, + worker.process.pid, + worker.process.exitcode, + ) + worker.process = None + dead_worker_ids.append(worker_id) + + for worker_id in dead_worker_ids: + del self.workers[worker_id] + + return len(dead_worker_ids) + def _start_workers(self, n): logger.debug("starting %d new workers", n) new_workers = [ diff --git a/tests/process_block_or_die.py b/tests/process_block_or_die.py new file mode 100644 index 00000000..49dd7f70 --- /dev/null +++ b/tests/process_block_or_die.py @@ -0,0 +1,29 @@ +"""Worker script that crashes on first invocation, works on subsequent ones. + +Uses a marker file to track whether a crash has already occurred. The first +worker to run creates the marker and exits via SystemExit (which bypasses +the normal exception handling in daisy's Worker._spawn_wrapper). Subsequent +workers see the marker and process blocks normally. +""" + +import daisy + +import os +import sys + +tmp_path = sys.argv[1] +marker = os.path.join(tmp_path, "worker_crashed") + +if not os.path.exists(marker): + # First worker: create marker and crash + with open(marker, "w") as f: + f.write("crashed") + raise SystemExit(1) + +# Subsequent workers: process blocks normally +client = daisy.Client() + +while True: + with client.acquire_block() as block: + if block is None: + break diff --git a/tests/test_dead_workers.py b/tests/test_dead_workers.py new file mode 100644 index 00000000..392cacff --- /dev/null +++ b/tests/test_dead_workers.py @@ -0,0 +1,55 @@ +"""Test that the server detects and replaces dead worker processes. + +Workers can die silently (e.g., SIGKILL/OOM, SystemExit) without queuing an +error. Without dead worker detection, the server would hang forever waiting +for messages from workers that no longer exist. +""" + +import daisy +from daisy.logging import set_log_basedir + +import logging +import os +import subprocess +import sys + +logging.basicConfig(level=logging.DEBUG) + + +def test_dead_worker_replacement(tmp_path): + """Workers that exit via SystemExit are detected and replaced. + + The first batch of workers raises SystemExit (simulating an OOM kill or + similar unrecoverable crash that bypasses normal exception handling). + The dead worker detection logic replaces them, and the replacement + workers complete the task successfully. + """ + set_log_basedir(tmp_path) + + def start_worker(): + subprocess.run( + [sys.executable, "tests/process_block_or_die.py", str(tmp_path)] + ) + + task = daisy.Task( + "test_dead_worker_task", + total_roi=daisy.Roi((0,), (10,)), + read_roi=daisy.Roi((0,), (10,)), + write_roi=daisy.Roi((0,), (10,)), + process_function=start_worker, + check_function=None, + read_write_conflict=False, + fit="valid", + num_workers=1, + max_retries=2, + timeout=None, + ) + + server = daisy.Server() + task_states = server.run_blockwise([task]) + assert task_states[task.task_id].is_done(), task_states[task.task_id] + + # Verify the crash marker exists (first worker did crash) + assert os.path.exists(tmp_path / "worker_crashed"), ( + "Expected first worker to crash and leave a marker file" + )