diff --git a/scheduler/helpers/timeouts.py b/scheduler/helpers/timeouts.py index 298ceb6..76dcedc 100644 --- a/scheduler/helpers/timeouts.py +++ b/scheduler/helpers/timeouts.py @@ -121,3 +121,11 @@ def cancel_death_penalty(self): return self._timer.cancel() self._timer = None + + +def get_default_death_penalty_class() -> type[BaseDeathPenalty]: + """Returns the default death penalty class based on the platform.""" + if hasattr(signal, "SIGALRM"): + return UnixSignalDeathPenalty + else: + return TimerDeathPenalty diff --git a/scheduler/types/settings_types.py b/scheduler/types/settings_types.py index 3ec1478..0e1e3ed 100644 --- a/scheduler/types/settings_types.py +++ b/scheduler/types/settings_types.py @@ -3,7 +3,7 @@ from enum import Enum from typing import Callable, Dict, Optional, List, Tuple, Any, Type, ClassVar, Set -from scheduler.helpers.timeouts import BaseDeathPenalty, UnixSignalDeathPenalty +from scheduler.helpers.timeouts import BaseDeathPenalty, get_default_death_penalty_class if sys.version_info >= (3, 11): from typing import Self @@ -40,7 +40,7 @@ class SchedulerConfiguration: DEFAULT_MAINTENANCE_TASK_INTERVAL: int = 10 * 60 # The interval to run maintenance tasks in seconds. 10 minutes. DEFAULT_JOB_MONITORING_INTERVAL: int = 30 # The interval to monitor jobs in seconds. SCHEDULER_FALLBACK_PERIOD_SECS: int = 120 # Period (secs) to wait before requiring to reacquire locks - DEATH_PENALTY_CLASS: Type[BaseDeathPenalty] = UnixSignalDeathPenalty + DEATH_PENALTY_CLASS: Type[BaseDeathPenalty] = get_default_death_penalty_class() @dataclass(slots=True, frozen=True, kw_only=True) diff --git a/scheduler/worker/worker.py b/scheduler/worker/worker.py index b664ce6..1ee0fcc 100644 --- a/scheduler/worker/worker.py +++ b/scheduler/worker/worker.py @@ -649,7 +649,7 @@ def execute_job(self, job: JobModel, queue: Queue) -> None: The worker will wait for the job execution process and make sure it executes within the given timeout bounds, or will end the job execution process with SIGALRM. """ - if self.fork_job_execution: + if hasattr(os, "fork") and self.fork_job_execution: self._model.set_field("state", WorkerStatus.BUSY, connection=self.connection) self.fork_job_execution_process(job, queue) self.monitor_job_execution_process(job, queue) @@ -839,7 +839,7 @@ def _ensure_list(obj: Any) -> List[Any]: def _calc_worker_name(existing_worker_names: Collection[str]) -> str: - hostname = os.uname()[1] + hostname = socket.gethostname() c = 1 worker_name = f"{hostname}-worker.{c}" while worker_name in existing_worker_names: