diff --git a/tests/unit/geometry/test_bbox_strider.py b/tests/unit/geometry/test_bbox_strider.py index 8fba1e7ac..ce7b5773f 100644 --- a/tests/unit/geometry/test_bbox_strider.py +++ b/tests/unit/geometry/test_bbox_strider.py @@ -2,7 +2,6 @@ import pytest -from zetta_utils import MULTIPROCESSING_NUM_TASKS_THRESHOLD from zetta_utils.geometry import BBox3D, BBoxStrider, IntVec3D, Vec3D @@ -56,11 +55,11 @@ def test_bbox_strider_get_all_chunks(mocker): ] -def test_bbox_strider_get_all_chunks_parallel(mocker): +def test_bbox_strider_get_all_chunks_large(mocker): strider = BBoxStrider( bbox=BBox3D.from_coords( start_coord=Vec3D(0, 0, 0), - end_coord=Vec3D(2, 1, MULTIPROCESSING_NUM_TASKS_THRESHOLD + 1), + end_coord=Vec3D(2, 1, 130), resolution=Vec3D(1, 1, 1), ), chunk_size=IntVec3D(1, 1, 1), diff --git a/tests/unit/mazepa/test_task.py b/tests/unit/mazepa/test_task.py index 714f937e8..dbbe47e74 100644 --- a/tests/unit/mazepa/test_task.py +++ b/tests/unit/mazepa/test_task.py @@ -16,6 +16,15 @@ from ..helpers import DummyException +# Module-level functions for spawn-compatible pebble process execution +def _slow_task_fn(): + time.sleep(0.3) + + +def _raising_task_fn(): + raise DummyException() + + def test_make_taskable_operation_cls() -> None: @taskable_operation_cls(operation_name="OpDummyClass1") @attrs.mutable @@ -56,24 +65,18 @@ def dummy_task_fn(): def test_task_runtime_limit() -> None: - @taskable_operation(runtime_limit_sec=0.1) - def dummy_task_fn(): - time.sleep(0.3) - - assert isinstance(dummy_task_fn, TaskableOperation) - task = dummy_task_fn.make_task() + slow_task_op = taskable_operation(runtime_limit_sec=0.1)(_slow_task_fn) + assert isinstance(slow_task_op, TaskableOperation) + task = slow_task_op.make_task() assert isinstance(task, Task) outcome = task(debug=False) assert isinstance(outcome.exception, MazepaTimeoutError) def test_task_no_handle_exc() -> None: - @taskable_operation(runtime_limit_sec=0.1) - def dummy_task_fn(): - raise DummyException() - - 
assert isinstance(dummy_task_fn, TaskableOperation) - task = dummy_task_fn.make_task() + raising_task_op = taskable_operation(runtime_limit_sec=10)(_raising_task_fn) + assert isinstance(raising_task_op, TaskableOperation) + task = raising_task_op.make_task() assert isinstance(task, Task) - with pytest.raises(Exception): + with pytest.raises(DummyException): task(debug=False, handle_exceptions=False) diff --git a/zetta_utils/__init__.py b/zetta_utils/__init__.py index a212c31ff..f5d3f5db4 100644 --- a/zetta_utils/__init__.py +++ b/zetta_utils/__init__.py @@ -35,14 +35,18 @@ def _patched_init(self, *args, **kwargs): _patch_gcsfs_for_proxy() -# Set global multiprocessing threshold -MULTIPROCESSING_NUM_TASKS_THRESHOLD = 128 +# Set global multiprocessing context +MULTIPROCESSING_CONTEXT = "spawn" -# Set start method to `forkserver` if not set elsewhere -# If not set here, `get_start_method` will set the default -# to `fork` w/o allow_none and cause issues with dependencies. +# Set start method to `spawn` if not set elsewhere. +# `fork` is unsafe after gRPC/CUDA initialization; `spawn` avoids this. 
if multiprocessing.get_start_method(allow_none=True) is None: - multiprocessing.set_start_method("forkserver") + multiprocessing.set_start_method(MULTIPROCESSING_CONTEXT) + + +def get_mp_context() -> multiprocessing.context.BaseContext: + """Get the multiprocessing context for the configured start method.""" + return multiprocessing.get_context(MULTIPROCESSING_CONTEXT) if "sphinx" not in sys.modules: # pragma: no cover import pdbp # noqa diff --git a/zetta_utils/geometry/bbox_strider.py b/zetta_utils/geometry/bbox_strider.py index 68288f9a1..b5e387f8f 100644 --- a/zetta_utils/geometry/bbox_strider.py +++ b/zetta_utils/geometry/bbox_strider.py @@ -1,15 +1,13 @@ # pylint: disable=missing-docstring, no-else-raise from __future__ import annotations -import multiprocessing from math import ceil, floor from typing import List, Literal, Optional, Tuple import attrs from typeguard import typechecked -from zetta_utils import MULTIPROCESSING_NUM_TASKS_THRESHOLD, builder, log -from zetta_utils.common import reset_signal_handlers +from zetta_utils import builder, log from zetta_utils.geometry.vec import VEC3D_PRECISION from . import Vec3D @@ -238,16 +236,7 @@ def shape(self) -> Vec3D[int]: # pragma: no cover def get_all_chunk_bboxes(self) -> List[BBox3D]: """Get all of the chunks.""" - if self.num_chunks > MULTIPROCESSING_NUM_TASKS_THRESHOLD: - with multiprocessing.get_context("fork").Pool( - initializer=reset_signal_handlers - ) as pool_obj: - result = pool_obj.map(self.get_nth_chunk_bbox, range(self.num_chunks)) - else: - result = [ - self.get_nth_chunk_bbox(i) for i in range(self.num_chunks) - ] # TODO: generator? 
- return result + return [self.get_nth_chunk_bbox(i) for i in range(self.num_chunks)] def _get_atomic_bbox(self, steps_along_dim: Vec3D[int]) -> BBox3D: if self.mode in ("shrink", "expand"): diff --git a/zetta_utils/mazepa/tasks.py b/zetta_utils/mazepa/tasks.py index 545cd720f..b1a3267f7 100644 --- a/zetta_utils/mazepa/tasks.py +++ b/zetta_utils/mazepa/tasks.py @@ -23,7 +23,7 @@ from pebble import concurrent from typing_extensions import ParamSpec -from zetta_utils import log +from zetta_utils import get_mp_context, log from . import constants, exceptions, id_generation from .task_outcome import TaskOutcome, TaskStatus @@ -75,9 +75,9 @@ def _call_task_fn(self, debug: bool = True) -> R_co: if debug or self.runtime_limit_sec is None: return_value = self.fn(*self.args, **self.kwargs) else: - future = concurrent.process(timeout=self.runtime_limit_sec)(self.fn)( - *self.args, **self.kwargs - ) + future = concurrent.process( + timeout=self.runtime_limit_sec, context=get_mp_context() + )(self.fn)(*self.args, **self.kwargs) try: return_value = future.result() except PebbleTimeoutError as e: diff --git a/zetta_utils/mazepa_addons/configurations/worker_pool.py b/zetta_utils/mazepa_addons/configurations/worker_pool.py index 9a4394b34..94754ee3d 100644 --- a/zetta_utils/mazepa_addons/configurations/worker_pool.py +++ b/zetta_utils/mazepa_addons/configurations/worker_pool.py @@ -8,7 +8,7 @@ import pebble -from zetta_utils import builder, log +from zetta_utils import builder, get_mp_context, log from zetta_utils.common import monitor_resources from zetta_utils.mazepa import SemaphoreType, Task, configure_semaphores, run_worker from zetta_utils.mazepa.pool_activity import PoolActivityTracker @@ -73,9 +73,7 @@ def setup_local_worker_pool( ) pool = pebble.ProcessPool( max_workers=num_procs, - context=multiprocessing.get_context( - "spawn" - ), # 'fork' has issues with CV sharded reads + context=get_mp_context(), initializer=worker_init, initargs=[ current_log_level, # log_level 
diff --git a/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py b/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py index 1957d1cf7..f2764589c 100644 --- a/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py +++ b/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py @@ -1,7 +1,6 @@ from __future__ import annotations import itertools -import multiprocessing from abc import ABC from copy import deepcopy from os import path @@ -24,8 +23,7 @@ from typeguard import suppress_type_checks from typing_extensions import ParamSpec -from zetta_utils import MULTIPROCESSING_NUM_TASKS_THRESHOLD, log, mazepa -from zetta_utils.common import reset_signal_handlers +from zetta_utils import log, mazepa from zetta_utils.geometry import Vec3D from zetta_utils.layer.volumetric import ( VolumetricBasedLayerProtocol, @@ -520,28 +518,15 @@ def _create_tasks( if self.task_stack_size is None or self.task_stack_size == 1: # No stacking, create one task per index - if len(idx_chunks_flat) > MULTIPROCESSING_NUM_TASKS_THRESHOLD: - with multiprocessing.get_context("fork").Pool( - initializer=reset_signal_handlers - ) as pool_obj: - tasks = pool_obj.map( - self._make_task, - zip( - idx_chunks_flat, - itertools.repeat(dst), - itertools.repeat(op_kwargs), - ), - ) - else: - tasks = list( - map( - self._make_task, - zip( - idx_chunks_flat, - itertools.repeat(dst), - itertools.repeat(op_kwargs), - ), - ) + tasks = list( + map( + self._make_task, + zip( + idx_chunks_flat, + itertools.repeat(dst), + itertools.repeat(op_kwargs), + ), + ) ) else: # Batching with stacked operations