diff --git a/dbos/_client.py b/dbos/_client.py
index 15ca2bed..78b16af1 100644
--- a/dbos/_client.py
+++ b/dbos/_client.py
@@ -18,6 +18,7 @@
 from dbos._app_db import ApplicationDatabase
 from dbos._context import MaxPriority, MinPriority
+from dbos._core import DEFAULT_POLLING_INTERVAL
 from dbos._sys_db import SystemDatabase
 from dbos._utils import generate_uuid
@@ -85,8 +86,12 @@ def __init__(self, workflow_id: str, sys_db: SystemDatabase):
     def get_workflow_id(self) -> str:
         return self.workflow_id

-    def get_result(self) -> R:
-        res: R = self._sys_db.await_workflow_result(self.workflow_id)
+    def get_result(
+        self, *, polling_interval_sec: float = DEFAULT_POLLING_INTERVAL
+    ) -> R:
+        res: R = self._sys_db.await_workflow_result(
+            self.workflow_id, polling_interval_sec
+        )
         return res

     def get_status(self) -> WorkflowStatus:
@@ -105,9 +110,11 @@ def __init__(self, workflow_id: str, sys_db: SystemDatabase):
     def get_workflow_id(self) -> str:
         return self.workflow_id

-    async def get_result(self) -> R:
+    async def get_result(
+        self, *, polling_interval_sec: float = DEFAULT_POLLING_INTERVAL
+    ) -> R:
         res: R = await asyncio.to_thread(
-            self._sys_db.await_workflow_result, self.workflow_id
+            self._sys_db.await_workflow_result, self.workflow_id, polling_interval_sec
         )
         return res
diff --git a/dbos/_core.py b/dbos/_core.py
index 133b8c9f..4df334ce 100644
--- a/dbos/_core.py
+++ b/dbos/_core.py
@@ -91,6 +91,7 @@
 TEMP_SEND_WF_NAME = ".temp_send_workflow"
 DEBOUNCER_WORKFLOW_NAME = "_dbos_debouncer_workflow"
+DEFAULT_POLLING_INTERVAL = 1.0


 class WorkflowHandleFuture(Generic[R]):
@@ -103,7 +104,9 @@ def __init__(self, workflow_id: str, future: Future[R], dbos: "DBOS"):
     def get_workflow_id(self) -> str:
         return self.workflow_id

-    def get_result(self) -> R:
+    def get_result(
+        self, *, polling_interval_sec: float = DEFAULT_POLLING_INTERVAL
+    ) -> R:
         try:
             r = self.future.result()
         except Exception as e:
@@ -130,9 +133,13 @@ def __init__(self, workflow_id: str, dbos: "DBOS"):
     def get_workflow_id(self) -> str:
         return self.workflow_id

-    def get_result(self) -> R:
+    def get_result(
+        self, *, polling_interval_sec: float = DEFAULT_POLLING_INTERVAL
+    ) -> R:
         try:
-            r: R = self.dbos._sys_db.await_workflow_result(self.workflow_id)
+            r: R = self.dbos._sys_db.await_workflow_result(
+                self.workflow_id, polling_interval_sec
+            )
         except Exception as e:
             serialized_e = self.dbos._serializer.serialize(e)
             self.dbos._sys_db.record_get_result(self.workflow_id, None, serialized_e)
@@ -158,7 +165,9 @@ def __init__(self, workflow_id: str, task: asyncio.Future[R], dbos: "DBOS"):
     def get_workflow_id(self) -> str:
         return self.workflow_id

-    async def get_result(self) -> R:
+    async def get_result(
+        self, *, polling_interval_sec: float = DEFAULT_POLLING_INTERVAL
+    ) -> R:
         try:
             r = await self.task
         except Exception as e:
@@ -192,10 +201,14 @@ def __init__(self, workflow_id: str, dbos: "DBOS"):
     def get_workflow_id(self) -> str:
         return self.workflow_id

-    async def get_result(self) -> R:
+    async def get_result(
+        self, *, polling_interval_sec: float = DEFAULT_POLLING_INTERVAL
+    ) -> R:
         try:
             r: R = await asyncio.to_thread(
-                self.dbos._sys_db.await_workflow_result, self.workflow_id
+                self.dbos._sys_db.await_workflow_result,
+                self.workflow_id,
+                polling_interval_sec,
             )
         except Exception as e:
             serialized_e = self.dbos._serializer.serialize(e)
@@ -366,7 +379,7 @@ def persist(func: Callable[[], R]) -> R:
             )
             # Directly return the result if the workflow is already completed
             recorded_result: R = dbos._sys_db.await_workflow_result(
-                status["workflow_uuid"]
+                status["workflow_uuid"], polling_interval=DEFAULT_POLLING_INTERVAL
             )
             return recorded_result
         try:
@@ -381,7 +394,9 @@ def persist(func: Callable[[], R]) -> R:
             return output
         except DBOSWorkflowConflictIDError:
             # Await the workflow result
-            r: R = dbos._sys_db.await_workflow_result(status["workflow_uuid"])
+            r: R = dbos._sys_db.await_workflow_result(
+                status["workflow_uuid"], polling_interval=DEFAULT_POLLING_INTERVAL
+            )
             return r
         except DBOSWorkflowCancelledError as error:
             raise DBOSAwaitedWorkflowCancelledError(status["workflow_uuid"])
@@ -788,7 +803,9 @@ def recorded_result(
     c_wfid: str, dbos: "DBOS"
 ) -> Callable[[Callable[[], R]], R]:
     def recorded_result_inner(func: Callable[[], R]) -> R:
-        r: R = dbos._sys_db.await_workflow_result(c_wfid)
+        r: R = dbos._sys_db.await_workflow_result(
+            c_wfid, polling_interval=DEFAULT_POLLING_INTERVAL
+        )
         return r

     return recorded_result_inner
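The handle-level change above threads a keyword-only `polling_interval_sec` through every `get_result` implementation, defaulting to the new `DEFAULT_POLLING_INTERVAL` (1.0 s). A minimal usage sketch; the workflow and the 0.25 s interval are illustrative, not part of the patch:

```python
from dbos import DBOS

@DBOS.workflow()
def my_workflow(x: int) -> int:
    return x * 2

# Poll the system database every 250 ms instead of the 1 s
# DEFAULT_POLLING_INTERVAL while awaiting the result.
handle = DBOS.start_workflow(my_workflow, 21)
assert handle.get_result(polling_interval_sec=0.25) == 42
```
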
status["workflow_uuid"] + status["workflow_uuid"], polling_interval=DEFAULT_POLLING_INTERVAL ) return recorded_result try: @@ -381,7 +394,9 @@ def persist(func: Callable[[], R]) -> R: return output except DBOSWorkflowConflictIDError: # Await the workflow result - r: R = dbos._sys_db.await_workflow_result(status["workflow_uuid"]) + r: R = dbos._sys_db.await_workflow_result( + status["workflow_uuid"], polling_interval=DEFAULT_POLLING_INTERVAL + ) return r except DBOSWorkflowCancelledError as error: raise DBOSAwaitedWorkflowCancelledError(status["workflow_uuid"]) @@ -788,7 +803,9 @@ def recorded_result( c_wfid: str, dbos: "DBOS" ) -> Callable[[Callable[[], R]], R]: def recorded_result_inner(func: Callable[[], R]) -> R: - r: R = dbos._sys_db.await_workflow_result(c_wfid) + r: R = dbos._sys_db.await_workflow_result( + c_wfid, polling_interval=DEFAULT_POLLING_INTERVAL + ) return r return recorded_result_inner diff --git a/dbos/_dbos.py b/dbos/_dbos.py index a236ddd3..09455ba6 100644 --- a/dbos/_dbos.py +++ b/dbos/_dbos.py @@ -38,6 +38,7 @@ from ._classproperty import classproperty from ._core import ( DEBOUNCER_WORKFLOW_NAME, + DEFAULT_POLLING_INTERVAL, TEMP_SEND_WF_NAME, WorkflowHandleAsyncPolling, WorkflowHandlePolling, @@ -335,6 +336,8 @@ def __init__( self._executor_field: Optional[ThreadPoolExecutor] = None self._background_threads: List[threading.Thread] = [] self.conductor_url: Optional[str] = conductor_url + if config.get("conductor_url"): + self.conductor_url = config.get("conductor_url") self.conductor_key: Optional[str] = conductor_key if config.get("conductor_key"): self.conductor_key = config.get("conductor_key") @@ -1551,7 +1554,9 @@ def get_workflow_id(self) -> str: """Return the applicable workflow ID.""" ... - def get_result(self) -> R: + def get_result( + self, *, polling_interval_sec: float = DEFAULT_POLLING_INTERVAL + ) -> R: """Return the result of the workflow function invocation, waiting if necessary.""" ... @@ -1580,7 +1585,9 @@ def get_workflow_id(self) -> str: """Return the applicable workflow ID.""" ... - async def get_result(self) -> R: + async def get_result( + self, *, polling_interval_sec: float = DEFAULT_POLLING_INTERVAL + ) -> R: """Return the result of the workflow function invocation, waiting if necessary.""" ... diff --git a/dbos/_dbos_config.py b/dbos/_dbos_config.py index 3027eec4..85ff7552 100644 --- a/dbos/_dbos_config.py +++ b/dbos/_dbos_config.py @@ -39,6 +39,7 @@ class DBOSConfig(TypedDict, total=False): enable_otlp (bool): If True, enable built-in DBOS OTLP tracing and logging. system_database_engine (sa.Engine): A custom system database engine. If provided, DBOS will not create an engine but use this instead. conductor_key (str): An API key for DBOS Conductor. Pass this in to connect your process to Conductor. + conductor_url (str): The websockets URL for your DBOS Conductor service. Only set if you're self-hosting Conductor. 
diff --git a/dbos/_queue.py b/dbos/_queue.py
index a45701aa..2a9002c5 100644
--- a/dbos/_queue.py
+++ b/dbos/_queue.py
@@ -44,6 +44,7 @@ def __init__(
         worker_concurrency: Optional[int] = None,
         priority_enabled: bool = False,
         partition_queue: bool = False,
+        polling_interval_sec: float = 1.0,
     ) -> None:
         if (
             worker_concurrency is not None
@@ -53,12 +54,15 @@
             raise ValueError(
                 "worker_concurrency must be less than or equal to concurrency"
             )
+        if polling_interval_sec <= 0.0:
+            raise ValueError("polling_interval_sec must be positive")
         self.name = name
         self.concurrency = concurrency
         self.worker_concurrency = worker_concurrency
         self.limiter = limiter
         self.priority_enabled = priority_enabled
         self.partition_queue = partition_queue
+        self.polling_interval_sec = polling_interval_sec
         from ._dbos import _get_or_create_dbos_registry

         registry = _get_or_create_dbos_registry()
@@ -108,50 +112,102 @@ async def enqueue_async(
         return await start_workflow_async(dbos, func, self.name, False, *args, **kwargs)


-def queue_thread(stop_event: threading.Event, dbos: "DBOS") -> None:
-    polling_interval = 1.0
-    min_polling_interval = 1.0
-    max_polling_interval = 120.0
+def queue_worker_thread(
+    stop_event: threading.Event, dbos: "DBOS", queue: Queue
+) -> None:
+    """Worker thread for processing a single queue."""
+    polling_interval = queue.polling_interval_sec
+    min_polling_interval = queue.polling_interval_sec
+    max_polling_interval = max(queue.polling_interval_sec, 120.0)
+
     while not stop_event.is_set():
         # Wait for the polling interval with jitter
         if stop_event.wait(timeout=polling_interval * random.uniform(0.95, 1.05)):
             return
-        queues = dict(dbos._registry.queue_info_map)
-        for _, queue in queues.items():
-            try:
-                if queue.partition_queue:
-                    dequeued_workflows = []
-                    queue_partition_keys = dbos._sys_db.get_queue_partitions(queue.name)
-                    for key in queue_partition_keys:
-                        dequeued_workflows += dbos._sys_db.start_queued_workflows(
-                            queue,
-                            GlobalParams.executor_id,
-                            GlobalParams.app_version,
-                            key,
-                        )
-                else:
-                    dequeued_workflows = dbos._sys_db.start_queued_workflows(
-                        queue, GlobalParams.executor_id, GlobalParams.app_version, None
-                    )
-                for id in dequeued_workflows:
-                    execute_workflow_by_id(dbos, id)
-            except OperationalError as e:
-                if isinstance(
-                    e.orig, (errors.SerializationFailure, errors.LockNotAvailable)
-                ):
-                    # If a serialization error is encountered, increase the polling interval
-                    polling_interval = min(
-                        max_polling_interval,
-                        polling_interval * 2.0,
-                    )
-                    dbos.logger.warning(
-                        f"Contention detected in queue thread for {queue.name}. Increasing polling interval to {polling_interval:.2f}."
+
+        try:
+            if queue.partition_queue:
+                dequeued_workflows = []
+                queue_partition_keys = dbos._sys_db.get_queue_partitions(queue.name)
+                for key in queue_partition_keys:
+                    dequeued_workflows += dbos._sys_db.start_queued_workflows(
+                        queue,
+                        GlobalParams.executor_id,
+                        GlobalParams.app_version,
+                        key,
                     )
-                else:
-                    dbos.logger.warning(f"Exception encountered in queue thread: {e}")
-            except Exception as e:
-                if not stop_event.is_set():
-                    # Only print the error if the thread is not stopping
-                    dbos.logger.warning(f"Exception encountered in queue thread: {e}")
+            else:
+                dequeued_workflows = dbos._sys_db.start_queued_workflows(
+                    queue, GlobalParams.executor_id, GlobalParams.app_version, None
+                )
+            for id in dequeued_workflows:
+                execute_workflow_by_id(dbos, id)
+        except OperationalError as e:
+            if isinstance(
+                e.orig, (errors.SerializationFailure, errors.LockNotAvailable)
+            ):
+                # If a serialization error is encountered, increase the polling interval
+                polling_interval = min(
+                    max_polling_interval,
+                    polling_interval * 2.0,
+                )
+                dbos.logger.warning(
+                    f"Contention detected in queue thread for {queue.name}. Increasing polling interval to {polling_interval:.2f}."
+                )
+            else:
+                dbos.logger.warning(
+                    f"Exception encountered in queue thread for {queue.name}: {e}"
+                )
+        except Exception as e:
+            if not stop_event.is_set():
+                # Only print the error if the thread is not stopping
+                dbos.logger.warning(
+                    f"Exception encountered in queue thread for {queue.name}: {e}"
+                )
+
         # Attempt to scale back the polling interval on each iteration
         polling_interval = max(min_polling_interval, polling_interval * 0.9)
+
+
+def queue_thread(stop_event: threading.Event, dbos: "DBOS") -> None:
+    """Main queue manager thread that spawns and monitors worker threads for each queue."""
+    queue_threads: dict[str, threading.Thread] = {}
+    check_interval = 1.0  # Check for new queues every second
+
+    while not stop_event.is_set():
+        # Check for new queues
+        current_queues = dict(dbos._registry.queue_info_map)
+
+        # Start threads for new queues
+        for queue_name, queue in current_queues.items():
+            if (
+                queue_name not in queue_threads
+                or not queue_threads[queue_name].is_alive()
+            ):
+                thread = threading.Thread(
+                    target=queue_worker_thread,
+                    args=(stop_event, dbos, queue),
+                    name=f"queue-worker-{queue_name}",
+                    daemon=True,
+                )
+                thread.start()
+                queue_threads[queue_name] = thread
+                dbos.logger.debug(f"Started worker thread for queue: {queue_name}")
+
+        # Wait for the check interval or stop event
+        if stop_event.wait(timeout=check_interval):
+            break
+
+    # Join all queue worker threads
+    dbos.logger.info("Stopping queue manager, joining all worker threads...")
+    for queue_name, thread in queue_threads.items():
+        if thread.is_alive():
+            thread.join(timeout=10.0)  # Give each thread 10 seconds to finish
+            if thread.is_alive():
+                dbos.logger.debug(
+                    f"Queue worker thread for {queue_name} did not stop in time"
+                )
+            else:
+                dbos.logger.debug(
+                    f"Queue worker thread for {queue_name} stopped successfully"
+                )
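The `_queue.py` rewrite splits the single global poller into one `queue_worker_thread` per queue, so exponential backoff on a contended queue no longer delays dequeueing from the others; the manager thread only discovers new queues and joins workers on shutdown. A sketch of how a caller opts into faster polling; the queue names, workflow, and intervals are illustrative:

```python
from dbos import DBOS, Queue

fast_queue = Queue("fast", polling_interval_sec=0.1)  # dequeue every ~100 ms
slow_queue = Queue("background")  # keeps the 1.0 s default

@DBOS.workflow()
def task() -> str:
    return "done"

# Enqueue on the fast queue and await with a matching result-poll interval.
handle = fast_queue.enqueue(task)
print(handle.get_result(polling_interval_sec=0.1))
```
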
diff --git a/dbos/_sys_db.py b/dbos/_sys_db.py
index e419c36e..f09544eb 100644
--- a/dbos/_sys_db.py
+++ b/dbos/_sys_db.py
@@ -930,7 +930,7 @@ def get_deduplicated_workflow(
         return workflow_id

     @db_retry()
-    def await_workflow_result(self, workflow_id: str) -> Any:
+    def await_workflow_result(self, workflow_id: str, polling_interval: float) -> Any:
         while True:
             with self.engine.begin() as c:
                 row = c.execute(
@@ -955,7 +955,7 @@ def await_workflow_result(self, workflow_id: str, polling_interval: float) -> Any:
                     raise DBOSAwaitedWorkflowCancelledError(workflow_id)
             else:
                 pass  # CB: I guess we're assuming the WF will show up eventually.
-            time.sleep(1)
+            time.sleep(polling_interval)

     def get_workflows(
         self,
@@ -998,11 +998,12 @@ def get_workflows(
         if input.queues_only:
             query = sa.select(*load_columns).where(
-                sa.and_(
-                    SystemSchema.workflow_status.c.queue_name.isnot(None),
-                    SystemSchema.workflow_status.c.status.in_(["ENQUEUED", "PENDING"]),
-                )
+                SystemSchema.workflow_status.c.queue_name.isnot(None),
             )
+            if not input.status:
+                query = query.where(
+                    SystemSchema.workflow_status.c.status.in_(["ENQUEUED", "PENDING"])
+                )
         else:
             query = sa.select(*load_columns)
         if input.sort_desc:
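The `get_workflows` change relaxes `queues_only` listing: the implicit ENQUEUED/PENDING restriction now applies only when the caller gives no explicit status filter. At the API surface this means roughly (a sketch; the `status=` usage mirrors the test below):

```python
active = DBOS.list_queued_workflows()  # still ENQUEUED/PENDING only
finished = DBOS.list_queued_workflows(status="SUCCESS")  # now permitted
```
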
diff --git a/tests/test_queue.py b/tests/test_queue.py
index 081660f9..a51a17b3 100644
--- a/tests/test_queue.py
+++ b/tests/test_queue.py
@@ -1695,3 +1695,19 @@ def normal_workflow() -> str:
     with pytest.raises(Exception):
         with SetEnqueueOptions(queue_partition_key="test"):
             partitionless_queue.enqueue(normal_workflow)
+
+
+def test_polling_interval(dbos: DBOS) -> None:
+    queue = Queue("queue", polling_interval_sec=0.1)
+
+    @DBOS.workflow()
+    def workflow() -> str:
+        assert DBOS.workflow_id
+        return DBOS.workflow_id
+
+    assert queue.enqueue(workflow).get_result()
+
+    for _ in range(10):
+        start_time = time.time()
+        assert queue.enqueue(workflow).get_result(polling_interval_sec=0.1)
+        assert time.time() - start_time < 1.0
diff --git a/tests/test_workflow_introspection.py b/tests/test_workflow_introspection.py
index e1fa678c..97bfbaa8 100644
--- a/tests/test_workflow_introspection.py
+++ b/tests/test_workflow_introspection.py
@@ -364,6 +364,8 @@ def blocking_step(i: int) -> int:
     assert handle.get_result() == [0, 1, 2, 3, 4]
     workflows = DBOS.list_queued_workflows()
     assert len(workflows) == 0
+    workflows = DBOS.list_queued_workflows(status="SUCCESS")
+    assert len(workflows) == queued_steps
     # Test the steps are listed properly
     steps = DBOS.list_workflow_steps(handle.workflow_id)
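For completeness, the client-side counterpart of the new parameter from the `_client.py` hunk; a sketch assuming an existing `DBOSClient` pointed at the application's database (the connection string and workflow ID are placeholders):

```python
from dbos import DBOSClient

client = DBOSClient("postgresql://postgres:password@localhost:5432/my_app")
handle = client.retrieve_workflow("my-workflow-id")
# Check for completion every half second instead of the 1 s default.
result = handle.get_result(polling_interval_sec=0.5)
```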