Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions tests/unit/geometry/test_bbox_strider.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pytest

from zetta_utils import MULTIPROCESSING_NUM_TASKS_THRESHOLD
from zetta_utils.geometry import BBox3D, BBoxStrider, IntVec3D, Vec3D


Expand Down Expand Up @@ -56,11 +55,11 @@ def test_bbox_strider_get_all_chunks(mocker):
]


def test_bbox_strider_get_all_chunks_parallel(mocker):
def test_bbox_strider_get_all_chunks_large(mocker):
strider = BBoxStrider(
bbox=BBox3D.from_coords(
start_coord=Vec3D(0, 0, 0),
end_coord=Vec3D(2, 1, MULTIPROCESSING_NUM_TASKS_THRESHOLD + 1),
end_coord=Vec3D(2, 1, 130),
resolution=Vec3D(1, 1, 1),
),
chunk_size=IntVec3D(1, 1, 1),
Expand Down
29 changes: 16 additions & 13 deletions tests/unit/mazepa/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@
from ..helpers import DummyException


# Module-level functions for spawn-compatible pebble process execution
def _slow_task_fn():
time.sleep(0.3)


def _raising_task_fn():
raise DummyException()


def test_make_taskable_operation_cls() -> None:
@taskable_operation_cls(operation_name="OpDummyClass1")
@attrs.mutable
Expand Down Expand Up @@ -56,24 +65,18 @@ def dummy_task_fn():


def test_task_runtime_limit() -> None:
@taskable_operation(runtime_limit_sec=0.1)
def dummy_task_fn():
time.sleep(0.3)

assert isinstance(dummy_task_fn, TaskableOperation)
task = dummy_task_fn.make_task()
slow_task_op = taskable_operation(runtime_limit_sec=0.1)(_slow_task_fn)
assert isinstance(slow_task_op, TaskableOperation)
task = slow_task_op.make_task()
assert isinstance(task, Task)
outcome = task(debug=False)
assert isinstance(outcome.exception, MazepaTimeoutError)


def test_task_no_handle_exc() -> None:
@taskable_operation(runtime_limit_sec=0.1)
def dummy_task_fn():
raise DummyException()

assert isinstance(dummy_task_fn, TaskableOperation)
task = dummy_task_fn.make_task()
raising_task_op = taskable_operation(runtime_limit_sec=10)(_raising_task_fn)
assert isinstance(raising_task_op, TaskableOperation)
task = raising_task_op.make_task()
assert isinstance(task, Task)
with pytest.raises(Exception):
with pytest.raises(DummyException):
task(debug=False, handle_exceptions=False)
16 changes: 10 additions & 6 deletions zetta_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,18 @@ def _patched_init(self, *args, **kwargs):
_patch_gcsfs_for_proxy()


# Set global multiprocessing threshold
MULTIPROCESSING_NUM_TASKS_THRESHOLD = 128
# Set global multiprocessing context
MULTIPROCESSING_CONTEXT = "spawn"

# Set start method to `forkserver` if not set elsewhere
# If not set here, `get_start_method` will set the default
# to `fork` w/o allow_none and cause issues with dependencies.
# Set start method to `spawn` if not set elsewhere.
# `fork` is unsafe after gRPC/CUDA initialization; `spawn` avoids this.
if multiprocessing.get_start_method(allow_none=True) is None:
multiprocessing.set_start_method("forkserver")
multiprocessing.set_start_method(MULTIPROCESSING_CONTEXT)


def get_mp_context() -> multiprocessing.context.BaseContext:
"""Get the multiprocessing context for the configured start method."""
return multiprocessing.get_context(MULTIPROCESSING_CONTEXT)

if "sphinx" not in sys.modules: # pragma: no cover
import pdbp # noqa
Expand Down
15 changes: 2 additions & 13 deletions zetta_utils/geometry/bbox_strider.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
# pylint: disable=missing-docstring, no-else-raise
from __future__ import annotations

import multiprocessing
from math import ceil, floor
from typing import List, Literal, Optional, Tuple

import attrs
from typeguard import typechecked

from zetta_utils import MULTIPROCESSING_NUM_TASKS_THRESHOLD, builder, log
from zetta_utils.common import reset_signal_handlers
from zetta_utils import builder, log
from zetta_utils.geometry.vec import VEC3D_PRECISION

from . import Vec3D
Expand Down Expand Up @@ -238,16 +236,7 @@ def shape(self) -> Vec3D[int]: # pragma: no cover

def get_all_chunk_bboxes(self) -> List[BBox3D]:
"""Get all of the chunks."""
if self.num_chunks > MULTIPROCESSING_NUM_TASKS_THRESHOLD:
with multiprocessing.get_context("fork").Pool(
initializer=reset_signal_handlers
) as pool_obj:
result = pool_obj.map(self.get_nth_chunk_bbox, range(self.num_chunks))
else:
result = [
self.get_nth_chunk_bbox(i) for i in range(self.num_chunks)
] # TODO: generator?
return result
return [self.get_nth_chunk_bbox(i) for i in range(self.num_chunks)]

def _get_atomic_bbox(self, steps_along_dim: Vec3D[int]) -> BBox3D:
if self.mode in ("shrink", "expand"):
Expand Down
8 changes: 4 additions & 4 deletions zetta_utils/mazepa/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pebble import concurrent
from typing_extensions import ParamSpec

from zetta_utils import log
from zetta_utils import get_mp_context, log

from . import constants, exceptions, id_generation
from .task_outcome import TaskOutcome, TaskStatus
Expand Down Expand Up @@ -75,9 +75,9 @@ def _call_task_fn(self, debug: bool = True) -> R_co:
if debug or self.runtime_limit_sec is None:
return_value = self.fn(*self.args, **self.kwargs)
else:
future = concurrent.process(timeout=self.runtime_limit_sec)(self.fn)(
*self.args, **self.kwargs
)
future = concurrent.process(
timeout=self.runtime_limit_sec, mp_context=get_mp_context()
)(self.fn)(*self.args, **self.kwargs)
try:
return_value = future.result()
except PebbleTimeoutError as e:
Expand Down
6 changes: 2 additions & 4 deletions zetta_utils/mazepa_addons/configurations/worker_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import pebble

from zetta_utils import builder, log
from zetta_utils import builder, get_mp_context, log
from zetta_utils.common import monitor_resources
from zetta_utils.mazepa import SemaphoreType, Task, configure_semaphores, run_worker
from zetta_utils.mazepa.pool_activity import PoolActivityTracker
Expand Down Expand Up @@ -73,9 +73,7 @@ def setup_local_worker_pool(
)
pool = pebble.ProcessPool(
max_workers=num_procs,
context=multiprocessing.get_context(
"spawn"
), # 'fork' has issues with CV sharded reads
context=get_mp_context(),
initializer=worker_init,
initargs=[
current_log_level, # log_level
Expand Down
35 changes: 10 additions & 25 deletions zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import itertools
import multiprocessing
from abc import ABC
from copy import deepcopy
from os import path
Expand All @@ -24,8 +23,7 @@
from typeguard import suppress_type_checks
from typing_extensions import ParamSpec

from zetta_utils import MULTIPROCESSING_NUM_TASKS_THRESHOLD, log, mazepa
from zetta_utils.common import reset_signal_handlers
from zetta_utils import log, mazepa
from zetta_utils.geometry import Vec3D
from zetta_utils.layer.volumetric import (
VolumetricBasedLayerProtocol,
Expand Down Expand Up @@ -520,28 +518,15 @@ def _create_tasks(

if self.task_stack_size is None or self.task_stack_size == 1:
# No stacking, create one task per index
if len(idx_chunks_flat) > MULTIPROCESSING_NUM_TASKS_THRESHOLD:
with multiprocessing.get_context("fork").Pool(
initializer=reset_signal_handlers
) as pool_obj:
tasks = pool_obj.map(
self._make_task,
zip(
idx_chunks_flat,
itertools.repeat(dst),
itertools.repeat(op_kwargs),
),
)
else:
tasks = list(
map(
self._make_task,
zip(
idx_chunks_flat,
itertools.repeat(dst),
itertools.repeat(op_kwargs),
),
)
tasks = list(
map(
self._make_task,
zip(
idx_chunks_flat,
itertools.repeat(dst),
itertools.repeat(op_kwargs),
),
)
)
else:
# Batching with stacked operations
Expand Down
Loading