82 changes: 78 additions & 4 deletions README.md
@@ -157,7 +157,8 @@ Task 3 and will run 2 seconds after Task 2 finishes.

Because of how `dramatiq-workflow` is implemented, each task in a workflow has
to know about the remaining tasks in the workflow that could potentially run
after it. When a workflow has a large number of tasks, this can lead to an
after it. By default, this is stored alongside your messages in the message
queue. When a workflow has a large number of tasks, this can lead to an
increase in memory usage in the broker and increased network traffic between
the broker and the workers, especially when using `Group` tasks: Each task in a
`Group` can potentially be the last one to finish, so each task has to retain a
@@ -173,9 +174,11 @@ There are a few things you can do to alleviate this issue:
- Consider breaking down large workflows into smaller partial workflows that
then schedule a subsequent workflow at the very end of the outermost `Chain`.

Lastly, you can use compression to reduce the size of the messages in your
queue. While dramatiq does not provide a compression implementation by default,
one can be added with just a few lines of code. For example:
#### Compression

You can use compression to reduce the size of the messages in your queue. While
dramatiq does not provide a compression implementation by default, one can be
added with just a few lines of code. For example:

```python
import dramatiq
@@ -197,6 +200,77 @@ class DramatiqLz4JSONEncoder(JSONEncoder):
dramatiq.set_encoder(DramatiqLz4JSONEncoder())
```

#### Callback Storage

To avoid storing large workflows in your message queue altogether, you can
provide a custom callback storage backend to the
`WorkflowMiddleware`. A callback storage backend is responsible for storing and
retrieving the list of callbacks. For example, you could implement a storage
backend that stores the callbacks in S3 and only stores a reference to the S3
object in the message options.

A storage backend must implement the `CallbackStorage` interface:

```python
from typing import Any
from dramatiq_workflow import CallbackStorage, SerializedCompletionCallbacks

class MyS3Storage(CallbackStorage):
def store(self, callbacks: SerializedCompletionCallbacks) -> Any:
# ... store in S3 and return a key
pass

def retrieve(self, ref: Any) -> SerializedCompletionCallbacks:
# ... retrieve from S3 using the key
pass
```

Then, you can pass an instance of your custom storage backend to the
`WorkflowMiddleware`:

```python
from dramatiq.rate_limits.backends import RedisBackend
from dramatiq_workflow import WorkflowMiddleware

backend = RedisBackend()
storage = MyS3Storage() # Your custom storage backend
broker.add_middleware(WorkflowMiddleware(backend, callback_storage=storage))
```
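
To make the stub above more concrete, here is a rough sketch of what an
S3-backed implementation could look like. It is not part of
`dramatiq-workflow`: it assumes `boto3`, an existing bucket, and that the
serialized callbacks can be round-tripped through JSON; the bucket name and
key prefix are purely illustrative.

```python
import json
import uuid
from typing import Any

import boto3  # assumed third-party dependency, not part of dramatiq-workflow

from dramatiq_workflow import CallbackStorage, SerializedCompletionCallbacks


class MyS3Storage(CallbackStorage):
    def __init__(self, bucket: str = "my-workflow-callbacks"):
        self._bucket = bucket
        self._s3 = boto3.client("s3")

    def store(self, callbacks: SerializedCompletionCallbacks) -> Any:
        # Write the callbacks to S3 under a random key and return that key as
        # the reference that ends up in the message options.
        key = f"callbacks/{uuid.uuid4()}"
        self._s3.put_object(Bucket=self._bucket, Key=key, Body=json.dumps(callbacks).encode())
        return key

    def retrieve(self, ref: Any) -> SerializedCompletionCallbacks:
        # `ref` is the key returned by `store`.
        body = self._s3.get_object(Bucket=self._bucket, Key=ref)["Body"].read()
        return [tuple(callback) for callback in json.loads(body)]
```

Old objects are never deleted in this sketch, so in practice you would pair it
with an S3 lifecycle rule (or similar cleanup) that expires keys once their
workflows have finished.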

##### Deduplicating Workflows

For convenience, `dramatiq-workflow` provides an abstract
`DedupWorkflowCallbackStorage` class that you can use to separate the storage
of workflows from the storage of callbacks. This is useful for deduplicating
large workflow definitions that may be part of multiple callbacks, especially
when chaining large groups of tasks.

To use it, you need to subclass `DedupWorkflowCallbackStorage` and implement
the `_store_workflow` and `_load_workflow` methods.

```python
from typing import Any
from dramatiq_workflow import DedupWorkflowCallbackStorage

class MyDedupStorage(DedupWorkflowCallbackStorage):
def __init__(self):
# In a real application, this would be a persistent storage like a
# database or a distributed cache so that workers and producers can
# both access it.
self.__workflow_storage = {}

def _store_workflow(self, id: str, workflow: dict) -> Any:
# Using the `id` (which is the completion ID) to deduplicate.
workflow_key = id
if workflow_key not in self.__workflow_storage:
self.__workflow_storage[workflow_key] = workflow
return workflow_key # Return a reference to the workflow.

def _load_workflow(self, id: str, ref: Any) -> dict:
# `ref` is what `_store_workflow` returned.
return self.__workflow_storage[ref]
```
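
The in-memory dictionary above only works inside a single process. As a rough
illustration of a persistent variant, here is a sketch backed by Redis. The
class name, key prefix, and TTL are made up for this example; it assumes the
`redis` package and that the serialized workflow dicts can be round-tripped
through JSON.

```python
import json
from typing import Any

import redis  # assumed third-party dependency, not part of dramatiq-workflow

from dramatiq_workflow import DedupWorkflowCallbackStorage


class RedisDedupStorage(DedupWorkflowCallbackStorage):
    def __init__(self, client: redis.Redis | None = None, ttl_seconds: int = 7 * 86400):
        self._redis = client or redis.Redis()
        self._ttl = ttl_seconds  # how long stored workflow definitions are kept

    def _store_workflow(self, id: str, workflow: dict) -> Any:
        key = f"dramatiq-workflow:{id}"
        # `nx=True` turns repeated stores for the same completion ID into
        # no-ops, which is what deduplicates the workflow definition.
        self._redis.set(key, json.dumps(workflow), nx=True, ex=self._ttl)
        return key  # Return a reference to the workflow.

    def _load_workflow(self, id: str, ref: Any) -> dict:
        raw = self._redis.get(ref)
        if raw is None:
            raise KeyError(f"Workflow {ref} has expired or was never stored")
        return json.loads(raw)
```

As with any other backend, you would pass an instance to `WorkflowMiddleware`
via the `callback_storage` argument. Pick a TTL that comfortably outlives your
longest-running workflows, otherwise a callback may reference a workflow
definition that has already expired.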

### Barrier

`dramatiq-workflow` uses a barrier mechanism to keep track of the current state
18 changes: 17 additions & 1 deletion dramatiq_workflow/__init__.py
@@ -1,11 +1,27 @@
from ._base import Workflow
from ._middleware import WorkflowMiddleware
from ._models import Chain, Group, Message, WithDelay, WorkflowType
from ._models import (
Chain,
Group,
LazyWorkflow,
Message,
SerializedCompletionCallback,
SerializedCompletionCallbacks,
WithDelay,
WorkflowType,
)
from ._storage import CallbackStorage, DedupWorkflowCallbackStorage, InlineCallbackStorage

__all__ = [
"CallbackStorage",
"Chain",
"DedupWorkflowCallbackStorage",
"Group",
"InlineCallbackStorage",
"LazyWorkflow",
"Message",
"SerializedCompletionCallback",
"SerializedCompletionCallbacks",
"WithDelay",
"Workflow",
"WorkflowMiddleware",
6 changes: 6 additions & 0 deletions dramatiq_workflow/_barrier.py
@@ -1,5 +1,9 @@
import logging

import dramatiq.rate_limits

logger = logging.getLogger(__name__)


class AtMostOnceBarrier(dramatiq.rate_limits.Barrier):
"""
@@ -33,6 +37,8 @@ def wait(self, *args, block=True, timeout=None):
released = super().wait(*args, block=False)
if released:
never_released = self.backend.incr(self.ran_key, 1, 0, self.ttl)
if not never_released:
logger.warning("Barrier %s release already recorded; ignoring subsequent release attempt", self.key)
return never_released

return False
8 changes: 7 additions & 1 deletion dramatiq_workflow/_base.py
@@ -10,6 +10,7 @@
from ._middleware import WorkflowMiddleware, workflow_noop
from ._models import Chain, Group, Message, SerializedCompletionCallbacks, WithDelay, WorkflowType
from ._serialize import serialize_workflow
from ._storage import CallbackStorage

logger = logging.getLogger(__name__)

@@ -167,7 +168,8 @@ def __schedule_noop(self, completion_callbacks: SerializedCompletionCallbacks):
def __augment_message(self, message: Message, completion_callbacks: SerializedCompletionCallbacks) -> Message:
options = {}
if completion_callbacks:
options = {OPTION_KEY_CALLBACKS: completion_callbacks}
callbacks_ref = self.__callback_storage.store(completion_callbacks)
options = {OPTION_KEY_CALLBACKS: callbacks_ref}

return message.copy(
# We reset the message timestamp to better represent the time the
@@ -200,6 +202,10 @@ def __rate_limiter_backend(self) -> dramatiq.rate_limits.RateLimiterBackend:
def __barrier_type(self) -> type[dramatiq.rate_limits.Barrier]:
return self.__middleware.barrier_type

@property
def __callback_storage(self) -> CallbackStorage:
return self.__middleware.callback_storage

def __create_barrier(self, count: int) -> str:
completion_uuid = str(uuid4())
completion_barrier = self.__barrier_type(self.__rate_limiter_backend, completion_uuid, ttl=CALLBACK_BARRIER_TTL)
17 changes: 10 additions & 7 deletions dramatiq_workflow/_middleware.py
@@ -8,6 +8,7 @@
from ._helpers import workflow_with_completion_callbacks
from ._models import SerializedCompletionCallbacks
from ._serialize import unserialize_workflow
from ._storage import CallbackStorage, InlineCallbackStorage

logger = logging.getLogger(__name__)

@@ -17,9 +18,11 @@ def __init__(
self,
rate_limiter_backend: dramatiq.rate_limits.RateLimiterBackend,
barrier_type: type[dramatiq.rate_limits.Barrier] = AtMostOnceBarrier,
callback_storage: CallbackStorage | None = None,
):
self.rate_limiter_backend = rate_limiter_backend
self.barrier_type = barrier_type
self.callback_storage = callback_storage or InlineCallbackStorage()

def after_process_boot(self, broker: dramatiq.Broker):
broker.declare_actor(workflow_noop)
@@ -34,10 +37,11 @@ def after_process_message(
if message.failed:
return

completion_callbacks: SerializedCompletionCallbacks | None = message.options.get(OPTION_KEY_CALLBACKS)
if completion_callbacks is None:
callbacks_ref = message.options.get(OPTION_KEY_CALLBACKS)
if callbacks_ref is None:
return

completion_callbacks = self.callback_storage.retrieve(callbacks_ref)
self._process_completion_callbacks(broker, completion_callbacks)

def _process_completion_callbacks(
@@ -46,11 +50,10 @@ def _process_completion_callbacks(
# Go through the completion callbacks backwards until we hit the first non-completed barrier
while len(completion_callbacks) > 0:
completion_id, remaining_workflow, propagate = completion_callbacks[-1]
if completion_id is not None:
barrier = self.barrier_type(self.rate_limiter_backend, completion_id)
if not barrier.wait(block=False):
logger.debug("Barrier not completed: %s", completion_id)
break
barrier = self.barrier_type(self.rate_limiter_backend, completion_id)
if not barrier.wait(block=False):
logger.debug("Barrier not completed: %s", completion_id)
break

logger.debug("Barrier completed: %s", completion_id)
completion_callbacks.pop()
5 changes: 4 additions & 1 deletion dramatiq_workflow/_models.py
@@ -1,3 +1,5 @@
import typing

import dramatiq
import dramatiq.rate_limits

@@ -39,5 +41,6 @@ def __eq__(self, other):
Message = dramatiq.Message
WorkflowType = Message | Chain | Group | WithDelay

SerializedCompletionCallback = tuple[str | None, dict | None, bool]
LazyWorkflow = typing.Callable[[], dict]
SerializedCompletionCallback = tuple[str, dict | LazyWorkflow | None, bool]
SerializedCompletionCallbacks = list[SerializedCompletionCallback]
27 changes: 20 additions & 7 deletions dramatiq_workflow/_serialize.py
@@ -48,12 +48,31 @@ def unserialize_workflow(workflow: typing.Any) -> WorkflowType:
Return an unserialized version of the workflow that can be used to create
a Workflow instance.
"""
result = unserialize_workflow_or_none(workflow)
if result is None:
raise ValueError("Cannot unserialize a workflow that resolves to None")
return result


def unserialize_workflow_or_none(workflow: typing.Any) -> WorkflowType | None:
"""
Return an unserialized version of the workflow that can be used to create
a Workflow instance.
"""
if callable(workflow):
workflow = workflow()

if workflow is None:
raise ValueError("Cannot unserialize None")
return None

if not isinstance(workflow, dict):
raise TypeError(f"Unsupported data type: {type(workflow)}")

return _unserialize_workflow_from_dict(workflow)


def _unserialize_workflow_from_dict(workflow: dict) -> WorkflowType:
"""Helper to unserialize from a dict, assuming it's not None and is a dict."""
workflow_type = workflow.pop("__type__")
match workflow_type:
case "message":
@@ -69,9 +88,3 @@ def unserialize_workflow(workflow: typing.Any) -> WorkflowType:
)

raise TypeError(f"Unsupported workflow type: {workflow_type}")


def unserialize_workflow_or_none(workflow: typing.Any) -> WorkflowType | None:
if workflow is None:
return None
return unserialize_workflow(workflow)