15 changes: 11 additions & 4 deletions dbos/_client.py
@@ -18,6 +18,7 @@

from dbos._app_db import ApplicationDatabase
from dbos._context import MaxPriority, MinPriority
from dbos._core import DEFAULT_POLLING_INTERVAL
from dbos._sys_db import SystemDatabase
from dbos._utils import generate_uuid

@@ -85,8 +86,12 @@ def __init__(self, workflow_id: str, sys_db: SystemDatabase):
def get_workflow_id(self) -> str:
return self.workflow_id

def get_result(self) -> R:
res: R = self._sys_db.await_workflow_result(self.workflow_id)
def get_result(
self, *, polling_interval_sec: float = DEFAULT_POLLING_INTERVAL
) -> R:
res: R = self._sys_db.await_workflow_result(
self.workflow_id, polling_interval_sec
)
return res

def get_status(self) -> WorkflowStatus:
@@ -105,9 +110,11 @@ def __init__(self, workflow_id: str, sys_db: SystemDatabase):
def get_workflow_id(self) -> str:
return self.workflow_id

async def get_result(self) -> R:
async def get_result(
self, *, polling_interval_sec: float = DEFAULT_POLLING_INTERVAL
) -> R:
res: R = await asyncio.to_thread(
self._sys_db.await_workflow_result, self.workflow_id
self._sys_db.await_workflow_result, self.workflow_id, polling_interval_sec
)
return res

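For readers skimming the diff: the client-side handles now take a keyword-only polling_interval_sec (defaulting to DEFAULT_POLLING_INTERVAL) and forward it to await_workflow_result. A minimal usage sketch; the DBOSClient constructor and retrieve_workflow call are assumed from the existing client API and are not part of this diff, and the connection string is a placeholder:

from dbos import DBOSClient

# Placeholder connection string for the application's system database.
client = DBOSClient("postgresql://user:pass@localhost:5432/my_app")

# Retrieve a handle to an already-started workflow and wait for its result,
# polling every 0.25 seconds instead of the 1-second default.
handle = client.retrieve_workflow("my-workflow-id")
result = handle.get_result(polling_interval_sec=0.25)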
35 changes: 26 additions & 9 deletions dbos/_core.py
@@ -91,6 +91,7 @@

TEMP_SEND_WF_NAME = "<temp>.temp_send_workflow"
DEBOUNCER_WORKFLOW_NAME = "_dbos_debouncer_workflow"
DEFAULT_POLLING_INTERVAL = 1.0


class WorkflowHandleFuture(Generic[R]):
@@ -103,7 +104,9 @@ def __init__(self, workflow_id: str, future: Future[R], dbos: "DBOS"):
def get_workflow_id(self) -> str:
return self.workflow_id

def get_result(self) -> R:
def get_result(
self, *, polling_interval_sec: float = DEFAULT_POLLING_INTERVAL
) -> R:
try:
r = self.future.result()
except Exception as e:
@@ -130,9 +133,13 @@ def __init__(self, workflow_id: str, dbos: "DBOS"):
def get_workflow_id(self) -> str:
return self.workflow_id

def get_result(self) -> R:
def get_result(
self, *, polling_interval_sec: float = DEFAULT_POLLING_INTERVAL
) -> R:
try:
r: R = self.dbos._sys_db.await_workflow_result(self.workflow_id)
r: R = self.dbos._sys_db.await_workflow_result(
self.workflow_id, polling_interval_sec
)
except Exception as e:
serialized_e = self.dbos._serializer.serialize(e)
self.dbos._sys_db.record_get_result(self.workflow_id, None, serialized_e)
@@ -158,7 +165,9 @@ def __init__(self, workflow_id: str, task: asyncio.Future[R], dbos: "DBOS"):
def get_workflow_id(self) -> str:
return self.workflow_id

async def get_result(self) -> R:
async def get_result(
self, *, polling_interval_sec: float = DEFAULT_POLLING_INTERVAL
) -> R:
try:
r = await self.task
except Exception as e:
@@ -192,10 +201,14 @@ def __init__(self, workflow_id: str, dbos: "DBOS"):
def get_workflow_id(self) -> str:
return self.workflow_id

async def get_result(self) -> R:
async def get_result(
self, *, polling_interval_sec: float = DEFAULT_POLLING_INTERVAL
) -> R:
try:
r: R = await asyncio.to_thread(
self.dbos._sys_db.await_workflow_result, self.workflow_id
self.dbos._sys_db.await_workflow_result,
self.workflow_id,
polling_interval_sec,
)
except Exception as e:
serialized_e = self.dbos._serializer.serialize(e)
@@ -366,7 +379,7 @@ def persist(func: Callable[[], R]) -> R:
)
# Directly return the result if the workflow is already completed
recorded_result: R = dbos._sys_db.await_workflow_result(
status["workflow_uuid"]
status["workflow_uuid"], polling_interval=DEFAULT_POLLING_INTERVAL
)
return recorded_result
try:
@@ -381,7 +394,9 @@ def persist(func: Callable[[], R]) -> R:
return output
except DBOSWorkflowConflictIDError:
# Await the workflow result
r: R = dbos._sys_db.await_workflow_result(status["workflow_uuid"])
r: R = dbos._sys_db.await_workflow_result(
status["workflow_uuid"], polling_interval=DEFAULT_POLLING_INTERVAL
)
return r
except DBOSWorkflowCancelledError as error:
raise DBOSAwaitedWorkflowCancelledError(status["workflow_uuid"])
@@ -788,7 +803,9 @@ def recorded_result(
c_wfid: str, dbos: "DBOS"
) -> Callable[[Callable[[], R]], R]:
def recorded_result_inner(func: Callable[[], R]) -> R:
r: R = dbos._sys_db.await_workflow_result(c_wfid)
r: R = dbos._sys_db.await_workflow_result(
c_wfid, polling_interval=DEFAULT_POLLING_INTERVAL
)
return r

return recorded_result_inner
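Taken together, the _core.py changes introduce DEFAULT_POLLING_INTERVAL = 1.0 and plumb the new keyword through all four in-process handle types. A rough sketch of how this surfaces to application code, assuming the existing DBOS.start_workflow entry point (only the polling_interval_sec keyword comes from this diff):

from dbos import DBOS

@DBOS.workflow()
def double(x: int) -> int:
    return x * 2

# Assumes DBOS has been configured and launched elsewhere. The keyword only
# affects handles that poll the system database (the *Polling variants);
# futures-backed handles return as soon as the underlying future resolves.
handle = DBOS.start_workflow(double, 21)
print(handle.get_result(polling_interval_sec=0.5))  # 42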
11 changes: 9 additions & 2 deletions dbos/_dbos.py
@@ -38,6 +38,7 @@
from ._classproperty import classproperty
from ._core import (
DEBOUNCER_WORKFLOW_NAME,
DEFAULT_POLLING_INTERVAL,
TEMP_SEND_WF_NAME,
WorkflowHandleAsyncPolling,
WorkflowHandlePolling,
@@ -335,6 +336,8 @@ def __init__(
self._executor_field: Optional[ThreadPoolExecutor] = None
self._background_threads: List[threading.Thread] = []
self.conductor_url: Optional[str] = conductor_url
if config.get("conductor_url"):
self.conductor_url = config.get("conductor_url")
self.conductor_key: Optional[str] = conductor_key
if config.get("conductor_key"):
self.conductor_key = config.get("conductor_key")
@@ -1551,7 +1554,9 @@ def get_workflow_id(self) -> str:
"""Return the applicable workflow ID."""
...

def get_result(self) -> R:
def get_result(
self, *, polling_interval_sec: float = DEFAULT_POLLING_INTERVAL
) -> R:
"""Return the result of the workflow function invocation, waiting if necessary."""
...

@@ -1580,7 +1585,9 @@ def get_workflow_id(self) -> str:
"""Return the applicable workflow ID."""
...

async def get_result(self) -> R:
async def get_result(
self, *, polling_interval_sec: float = DEFAULT_POLLING_INTERVAL
) -> R:
"""Return the result of the workflow function invocation, waiting if necessary."""
...

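The _dbos.py changes update the WorkflowHandle and WorkflowHandleAsync protocol definitions to match, and allow conductor_url to be set from config. The async protocol mirrors the sync one; a sketch under the assumption that DBOS.start_workflow_async is the existing async entry point:

import asyncio

from dbos import DBOS

@DBOS.workflow()
async def greet(name: str) -> str:
    return f"hello, {name}"

async def main() -> None:
    # Assumes DBOS has been configured and launched elsewhere.
    handle = await DBOS.start_workflow_async(greet, "world")
    # The async handle calls await_workflow_result in a worker thread, so the
    # same keyword controls its database polling cadence.
    print(await handle.get_result(polling_interval_sec=2.0))

asyncio.run(main())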
2 changes: 2 additions & 0 deletions dbos/_dbos_config.py
@@ -39,6 +39,7 @@ class DBOSConfig(TypedDict, total=False):
enable_otlp (bool): If True, enable built-in DBOS OTLP tracing and logging.
system_database_engine (sa.Engine): A custom system database engine. If provided, DBOS will not create an engine but use this instead.
conductor_key (str): An API key for DBOS Conductor. Pass this in to connect your process to Conductor.
conductor_url (str): The websockets URL for your DBOS Conductor service. Only set if you're self-hosting Conductor.
serializer (Serializer): A custom serializer and deserializer DBOS uses when storing program data in the system database
"""

@@ -60,6 +61,7 @@
enable_otlp: Optional[bool]
system_database_engine: Optional[sa.Engine]
conductor_key: Optional[str]
conductor_url: Optional[str]
serializer: Optional[Serializer]


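The config change is the counterpart to the conductor_url plumbing in _dbos.py above. A sketch of the new key in use; all values are placeholders:

from dbos import DBOS, DBOSConfig

config: DBOSConfig = {
    "name": "my-app",
    "database_url": "postgresql://user:pass@localhost:5432/my_app",
    # New key from this change: only needed when self-hosting Conductor;
    # conductor_key still authenticates the connection.
    "conductor_url": "wss://conductor.internal.example.com/websocket",
    "conductor_key": "my-conductor-api-key",
}

DBOS(config=config)
DBOS.launch()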
136 changes: 96 additions & 40 deletions dbos/_queue.py
@@ -44,6 +44,7 @@ def __init__(
worker_concurrency: Optional[int] = None,
priority_enabled: bool = False,
partition_queue: bool = False,
polling_interval_sec: float = 1.0,
) -> None:
if (
worker_concurrency is not None
@@ -53,12 +54,15 @@
raise ValueError(
"worker_concurrency must be less than or equal to concurrency"
)
if polling_interval_sec <= 0.0:
raise ValueError("polling_interval_sec must be positive")
self.name = name
self.concurrency = concurrency
self.worker_concurrency = worker_concurrency
self.limiter = limiter
self.priority_enabled = priority_enabled
self.partition_queue = partition_queue
self.polling_interval_sec = polling_interval_sec
from ._dbos import _get_or_create_dbos_registry

registry = _get_or_create_dbos_registry()
@@ -108,50 +112,102 @@ async def enqueue_async(
return await start_workflow_async(dbos, func, self.name, False, *args, **kwargs)


def queue_thread(stop_event: threading.Event, dbos: "DBOS") -> None:
polling_interval = 1.0
min_polling_interval = 1.0
max_polling_interval = 120.0
def queue_worker_thread(
stop_event: threading.Event, dbos: "DBOS", queue: Queue
) -> None:
"""Worker thread for processing a single queue."""
polling_interval = queue.polling_interval_sec
min_polling_interval = queue.polling_interval_sec
max_polling_interval = max(queue.polling_interval_sec, 120.0)

while not stop_event.is_set():
# Wait for the polling interval with jitter
if stop_event.wait(timeout=polling_interval * random.uniform(0.95, 1.05)):
return
queues = dict(dbos._registry.queue_info_map)
for _, queue in queues.items():
try:
if queue.partition_queue:
dequeued_workflows = []
queue_partition_keys = dbos._sys_db.get_queue_partitions(queue.name)
for key in queue_partition_keys:
dequeued_workflows += dbos._sys_db.start_queued_workflows(
queue,
GlobalParams.executor_id,
GlobalParams.app_version,
key,
)
else:
dequeued_workflows = dbos._sys_db.start_queued_workflows(
queue, GlobalParams.executor_id, GlobalParams.app_version, None
)
for id in dequeued_workflows:
execute_workflow_by_id(dbos, id)
except OperationalError as e:
if isinstance(
e.orig, (errors.SerializationFailure, errors.LockNotAvailable)
):
# If a serialization error is encountered, increase the polling interval
polling_interval = min(
max_polling_interval,
polling_interval * 2.0,
)
dbos.logger.warning(
f"Contention detected in queue thread for {queue.name}. Increasing polling interval to {polling_interval:.2f}."

try:
if queue.partition_queue:
dequeued_workflows = []
queue_partition_keys = dbos._sys_db.get_queue_partitions(queue.name)
for key in queue_partition_keys:
dequeued_workflows += dbos._sys_db.start_queued_workflows(
queue,
GlobalParams.executor_id,
GlobalParams.app_version,
key,
)
else:
dbos.logger.warning(f"Exception encountered in queue thread: {e}")
except Exception as e:
if not stop_event.is_set():
# Only print the error if the thread is not stopping
dbos.logger.warning(f"Exception encountered in queue thread: {e}")
else:
dequeued_workflows = dbos._sys_db.start_queued_workflows(
queue, GlobalParams.executor_id, GlobalParams.app_version, None
)
for id in dequeued_workflows:
execute_workflow_by_id(dbos, id)
except OperationalError as e:
if isinstance(
e.orig, (errors.SerializationFailure, errors.LockNotAvailable)
):
# If a serialization error is encountered, increase the polling interval
polling_interval = min(
max_polling_interval,
polling_interval * 2.0,
)
dbos.logger.warning(
f"Contention detected in queue thread for {queue.name}. Increasing polling interval to {polling_interval:.2f}."
)
else:
dbos.logger.warning(
f"Exception encountered in queue thread for {queue.name}: {e}"
)
except Exception as e:
if not stop_event.is_set():
# Only print the error if the thread is not stopping
dbos.logger.warning(
f"Exception encountered in queue thread for {queue.name}: {e}"
)

# Attempt to scale back the polling interval on each iteration
polling_interval = max(min_polling_interval, polling_interval * 0.9)


def queue_thread(stop_event: threading.Event, dbos: "DBOS") -> None:
"""Main queue manager thread that spawns and monitors worker threads for each queue."""
queue_threads: dict[str, threading.Thread] = {}
check_interval = 1.0 # Check for new queues every second

while not stop_event.is_set():
# Check for new queues
current_queues = dict(dbos._registry.queue_info_map)

# Start threads for new queues
for queue_name, queue in current_queues.items():
if (
queue_name not in queue_threads
or not queue_threads[queue_name].is_alive()
):
thread = threading.Thread(
target=queue_worker_thread,
args=(stop_event, dbos, queue),
name=f"queue-worker-{queue_name}",
daemon=True,
)
thread.start()
queue_threads[queue_name] = thread
dbos.logger.debug(f"Started worker thread for queue: {queue_name}")

# Wait for the check interval or stop event
if stop_event.wait(timeout=check_interval):
break

# Join all queue worker threads
dbos.logger.info("Stopping queue manager, joining all worker threads...")
for queue_name, thread in queue_threads.items():
if thread.is_alive():
thread.join(timeout=10.0) # Give each thread 10 seconds to finish
if thread.is_alive():
dbos.logger.debug(
f"Queue worker thread for {queue_name} did not stop in time"
)
else:
dbos.logger.debug(
f"Queue worker thread for {queue_name} stopped successfully"
)
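In summary, _queue.py replaces the single polling loop with one worker thread per queue, each honoring its queue's polling_interval_sec (validated as positive in the constructor). On contention the interval doubles, capped at max(polling_interval_sec, 120.0), then decays by 10% per iteration back toward the configured minimum. A usage sketch; queue and workflow names are illustrative, and Queue/enqueue are the existing APIs:

from dbos import DBOS, Queue

# This queue's dedicated worker thread polls for work every 0.1 seconds
# instead of the previous hard-coded 1-second interval.
fast_queue = Queue("fast-queue", concurrency=10, polling_interval_sec=0.1)

@DBOS.workflow()
def process_item(item_id: int) -> str:
    return f"processed {item_id}"

# Assumes DBOS has been configured and launched elsewhere.
handle = fast_queue.enqueue(process_item, 42)
print(handle.get_result())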
13 changes: 7 additions & 6 deletions dbos/_sys_db.py
@@ -930,7 +930,7 @@ def get_deduplicated_workflow(
return workflow_id

@db_retry()
def await_workflow_result(self, workflow_id: str) -> Any:
def await_workflow_result(self, workflow_id: str, polling_interval: float) -> Any:
while True:
with self.engine.begin() as c:
row = c.execute(
@@ -955,7 +955,7 @@ def await_workflow_result(self, workflow_id: str) -> Any:
raise DBOSAwaitedWorkflowCancelledError(workflow_id)
else:
pass # CB: I guess we're assuming the WF will show up eventually.
time.sleep(1)
time.sleep(polling_interval)

def get_workflows(
self,
@@ -998,11 +998,12 @@ def get_workflows(

if input.queues_only:
query = sa.select(*load_columns).where(
sa.and_(
SystemSchema.workflow_status.c.queue_name.isnot(None),
SystemSchema.workflow_status.c.status.in_(["ENQUEUED", "PENDING"]),
)
SystemSchema.workflow_status.c.queue_name.isnot(None),
)
if not input.status:
query = query.where(
SystemSchema.workflow_status.c.status.in_(["ENQUEUED", "PENDING"])
)
else:
query = sa.select(*load_columns)
if input.sort_desc:
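Two smaller _sys_db.py adjustments round out the change: await_workflow_result now requires the caller to pass the polling interval explicitly, and the queues_only listing no longer forces the ENQUEUED/PENDING restriction when the caller supplies a status of their own. A sketch of what the relaxed filter enables, assuming the existing DBOS.list_queued_workflows wrapper forwards the status argument (that wrapper is not part of this diff):

from dbos import DBOS

# With the relaxed filter an explicit status wins, so workflows that ran on
# a queue but have already finished can now be listed as well.
finished = DBOS.list_queued_workflows(status="SUCCESS")
for wf in finished:
    print(wf.workflow_id, wf.status, wf.queue_name)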