diff --git a/README.md b/README.md index f493f63..8a97d0c 100644 --- a/README.md +++ b/README.md @@ -157,7 +157,8 @@ Task 3 and will run 2 seconds after Task 2 finishes. Because of how `dramatiq-workflow` is implemented, each task in a workflow has to know about the remaining tasks in the workflow that could potentially run -after it. When a workflow has a large number of tasks, this can lead to an +after it. By default, this is stored alongside your messages in the message +queue. When a workflow has a large number of tasks, it can lead to an increase of memory usage in the broker and increased network traffic between the broker and the workers, especially when using `Group` tasks: Each task in a `Group` can potentially be the last one to finish, so each task has to retain a @@ -173,9 +174,11 @@ There are a few things you can do to alleviate this issue: - Consider breaking down large workflows into smaller partial workflows that then schedule a subsequent workflow at the very end of the outermost `Chain`. -Lastly, you can use compression to reduce the size of the messages in your -queue. While dramatiq does not provide a compression implementation by default, -one can be added with just a few lines of code. For example: +#### Compression + +You can use compression to reduce the size of the messages in your queue. While +dramatiq does not provide a compression implementation by default, one can be +added with just a few lines of code. For example: ```python import dramatiq @@ -197,6 +200,77 @@ class DramatiqLz4JSONEncoder(JSONEncoder): dramatiq.set_encoder(DramatiqLz4JSONEncoder()) ``` +#### Callback Storage + +To completely eliminate the issue of large workflows being stored in your +message queue, you can provide a custom callback storage backend to the +`WorkflowMiddleware`. A callback storage backend is responsible for storing and +retrieving the list of callbacks. For example, you could implement a storage +backend that stores the callbacks in S3 and only stores a reference to the S3 +object in the message options. + +A storage backend must implement the `CallbackStorage` interface: + +```python +from typing import Any +from dramatiq_workflow import CallbackStorage, SerializedCompletionCallbacks + +class MyS3Storage(CallbackStorage): + def store(self, callbacks: SerializedCompletionCallbacks) -> Any: + # ... store in S3 and return a key + pass + + def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: + # ... retrieve from S3 using the key + pass +``` + +Then, you can pass an instance of your custom storage backend to the +`WorkflowMiddleware`: + +```python +from dramatiq.rate_limits.backends import RedisBackend +from dramatiq_workflow import WorkflowMiddleware + +backend = RedisBackend() +storage = MyS3Storage() # Your custom storage backend +broker.add_middleware(WorkflowMiddleware(backend, callback_storage=storage)) +``` + +##### Deduplicating Workflows + +For convenience, `dramatiq-workflow` provides an abstract +`DedupWorkflowCallbackStorage` class that you can use to separate the storage +of workflows from the storage of callbacks. This is useful for deduplicating +large workflow definitions that may be part of multiple callbacks, especially +when chaining large groups of tasks. + +To use it, you need to subclass `DedupWorkflowCallbackStorage` and implement +the `_store_workflow` and `_load_workflow` methods. 
+ +```python +from typing import Any +from dramatiq_workflow import DedupWorkflowCallbackStorage + +class MyDedupStorage(DedupWorkflowCallbackStorage): + def __init__(self): + # In a real application, this would be a persistent storage like a + # database or a distributed cache so that workers and producers can + # both access it. + self.__workflow_storage = {} + + def _store_workflow(self, id: str, workflow: dict) -> Any: + # Using the `id` (which is the completion ID) to deduplicate. + workflow_key = id + if workflow_key not in self.__workflow_storage: + self.__workflow_storage[workflow_key] = workflow + return workflow_key # Return a reference to the workflow. + + def _load_workflow(self, id: str, ref: Any) -> dict: + # `ref` is what `_store_workflow` returned. + return self.__workflow_storage[ref] +``` + ### Barrier `dramatiq-workflow` uses a barrier mechanism to keep track of the current state diff --git a/dramatiq_workflow/__init__.py b/dramatiq_workflow/__init__.py index 3f003c9..4f10d08 100644 --- a/dramatiq_workflow/__init__.py +++ b/dramatiq_workflow/__init__.py @@ -1,11 +1,27 @@ from ._base import Workflow from ._middleware import WorkflowMiddleware -from ._models import Chain, Group, Message, WithDelay, WorkflowType +from ._models import ( + Chain, + Group, + LazyWorkflow, + Message, + SerializedCompletionCallback, + SerializedCompletionCallbacks, + WithDelay, + WorkflowType, +) +from ._storage import CallbackStorage, DedupWorkflowCallbackStorage, InlineCallbackStorage __all__ = [ + "CallbackStorage", "Chain", + "DedupWorkflowCallbackStorage", "Group", + "InlineCallbackStorage", + "LazyWorkflow", "Message", + "SerializedCompletionCallback", + "SerializedCompletionCallbacks", "WithDelay", "Workflow", "WorkflowMiddleware", diff --git a/dramatiq_workflow/_barrier.py b/dramatiq_workflow/_barrier.py index a1377c3..b7ca619 100644 --- a/dramatiq_workflow/_barrier.py +++ b/dramatiq_workflow/_barrier.py @@ -1,5 +1,9 @@ +import logging + import dramatiq.rate_limits +logger = logging.getLogger(__name__) + class AtMostOnceBarrier(dramatiq.rate_limits.Barrier): """ @@ -33,6 +37,8 @@ def wait(self, *args, block=True, timeout=None): released = super().wait(*args, block=False) if released: never_released = self.backend.incr(self.ran_key, 1, 0, self.ttl) + if not never_released: + logger.warning("Barrier %s release already recorded; ignoring subsequent release attempt", self.key) return never_released return False diff --git a/dramatiq_workflow/_base.py b/dramatiq_workflow/_base.py index 2ad018a..b621f38 100644 --- a/dramatiq_workflow/_base.py +++ b/dramatiq_workflow/_base.py @@ -10,6 +10,7 @@ from ._middleware import WorkflowMiddleware, workflow_noop from ._models import Chain, Group, Message, SerializedCompletionCallbacks, WithDelay, WorkflowType from ._serialize import serialize_workflow +from ._storage import CallbackStorage logger = logging.getLogger(__name__) @@ -167,7 +168,8 @@ def __schedule_noop(self, completion_callbacks: SerializedCompletionCallbacks): def __augment_message(self, message: Message, completion_callbacks: SerializedCompletionCallbacks) -> Message: options = {} if completion_callbacks: - options = {OPTION_KEY_CALLBACKS: completion_callbacks} + callbacks_ref = self.__callback_storage.store(completion_callbacks) + options = {OPTION_KEY_CALLBACKS: callbacks_ref} return message.copy( # We reset the message timestamp to better represent the time the @@ -200,6 +202,10 @@ def __rate_limiter_backend(self) -> dramatiq.rate_limits.RateLimiterBackend: def __barrier_type(self) -> 
type[dramatiq.rate_limits.Barrier]: return self.__middleware.barrier_type + @property + def __callback_storage(self) -> CallbackStorage: + return self.__middleware.callback_storage + def __create_barrier(self, count: int) -> str: completion_uuid = str(uuid4()) completion_barrier = self.__barrier_type(self.__rate_limiter_backend, completion_uuid, ttl=CALLBACK_BARRIER_TTL) diff --git a/dramatiq_workflow/_middleware.py b/dramatiq_workflow/_middleware.py index 1f2eb5d..b8907fd 100644 --- a/dramatiq_workflow/_middleware.py +++ b/dramatiq_workflow/_middleware.py @@ -8,6 +8,7 @@ from ._helpers import workflow_with_completion_callbacks from ._models import SerializedCompletionCallbacks from ._serialize import unserialize_workflow +from ._storage import CallbackStorage, InlineCallbackStorage logger = logging.getLogger(__name__) @@ -17,9 +18,11 @@ def __init__( self, rate_limiter_backend: dramatiq.rate_limits.RateLimiterBackend, barrier_type: type[dramatiq.rate_limits.Barrier] = AtMostOnceBarrier, + callback_storage: CallbackStorage | None = None, ): self.rate_limiter_backend = rate_limiter_backend self.barrier_type = barrier_type + self.callback_storage = callback_storage or InlineCallbackStorage() def after_process_boot(self, broker: dramatiq.Broker): broker.declare_actor(workflow_noop) @@ -34,10 +37,11 @@ def after_process_message( if message.failed: return - completion_callbacks: SerializedCompletionCallbacks | None = message.options.get(OPTION_KEY_CALLBACKS) - if completion_callbacks is None: + callbacks_ref = message.options.get(OPTION_KEY_CALLBACKS) + if callbacks_ref is None: return + completion_callbacks = self.callback_storage.retrieve(callbacks_ref) self._process_completion_callbacks(broker, completion_callbacks) def _process_completion_callbacks( @@ -46,11 +50,10 @@ def _process_completion_callbacks( # Go through the completion callbacks backwards until we hit the first non-completed barrier while len(completion_callbacks) > 0: completion_id, remaining_workflow, propagate = completion_callbacks[-1] - if completion_id is not None: - barrier = self.barrier_type(self.rate_limiter_backend, completion_id) - if not barrier.wait(block=False): - logger.debug("Barrier not completed: %s", completion_id) - break + barrier = self.barrier_type(self.rate_limiter_backend, completion_id) + if not barrier.wait(block=False): + logger.debug("Barrier not completed: %s", completion_id) + break logger.debug("Barrier completed: %s", completion_id) completion_callbacks.pop() diff --git a/dramatiq_workflow/_models.py b/dramatiq_workflow/_models.py index 68b7cd9..03c26a1 100644 --- a/dramatiq_workflow/_models.py +++ b/dramatiq_workflow/_models.py @@ -1,3 +1,5 @@ +import typing + import dramatiq import dramatiq.rate_limits @@ -39,5 +41,6 @@ def __eq__(self, other): Message = dramatiq.Message WorkflowType = Message | Chain | Group | WithDelay -SerializedCompletionCallback = tuple[str | None, dict | None, bool] +LazyWorkflow = typing.Callable[[], dict] +SerializedCompletionCallback = tuple[str, dict | LazyWorkflow | None, bool] SerializedCompletionCallbacks = list[SerializedCompletionCallback] diff --git a/dramatiq_workflow/_serialize.py b/dramatiq_workflow/_serialize.py index 2d0ad89..60caec9 100644 --- a/dramatiq_workflow/_serialize.py +++ b/dramatiq_workflow/_serialize.py @@ -48,12 +48,31 @@ def unserialize_workflow(workflow: typing.Any) -> WorkflowType: Return an unserialized version of the workflow that can be used to create a Workflow instance. 
""" + result = unserialize_workflow_or_none(workflow) + if result is None: + raise ValueError("Cannot unserialize a workflow that resolves to None") + return result + + +def unserialize_workflow_or_none(workflow: typing.Any) -> WorkflowType | None: + """ + Return an unserialized version of the workflow that can be used to create + a Workflow instance. + """ + if callable(workflow): + workflow = workflow() + if workflow is None: - raise ValueError("Cannot unserialize None") + return None if not isinstance(workflow, dict): raise TypeError(f"Unsupported data type: {type(workflow)}") + return _unserialize_workflow_from_dict(workflow) + + +def _unserialize_workflow_from_dict(workflow: dict) -> WorkflowType: + """Helper to unserialize from a dict, assuming it's not None and is a dict.""" workflow_type = workflow.pop("__type__") match workflow_type: case "message": @@ -69,9 +88,3 @@ def unserialize_workflow(workflow: typing.Any) -> WorkflowType: ) raise TypeError(f"Unsupported workflow type: {workflow_type}") - - -def unserialize_workflow_or_none(workflow: typing.Any) -> WorkflowType | None: - if workflow is None: - return None - return unserialize_workflow(workflow) diff --git a/dramatiq_workflow/_storage.py b/dramatiq_workflow/_storage.py new file mode 100644 index 0000000..11d2fbe --- /dev/null +++ b/dramatiq_workflow/_storage.py @@ -0,0 +1,131 @@ +import abc +from functools import partial +from typing import Any, Generic, TypeVar, cast + +from ._models import LazyWorkflow, SerializedCompletionCallbacks + +CallbacksRefT = TypeVar("CallbacksRefT") + + +class CallbackStorage(abc.ABC, Generic[CallbacksRefT]): + """ + Abstract base class for callback storage backends. + """ + + @abc.abstractmethod + def store(self, callbacks: SerializedCompletionCallbacks) -> CallbacksRefT: + """ + Stores callbacks and returns a reference to them. + + This reference will be stored in the dramatiq message options. It must + be serializable by the broker's encoder (e.g. JSON). + + Args: + callbacks: The callbacks to store. + + Returns: + A serializable reference to the stored callbacks. + """ + raise NotImplementedError + + @abc.abstractmethod + def retrieve(self, ref: CallbacksRefT) -> SerializedCompletionCallbacks: + """ + Retrieves callbacks using a reference. + + Args: + ref: The reference to the callbacks, as returned by `store`. + + Returns: + The retrieved callbacks. + """ + raise NotImplementedError + + +class InlineCallbackStorage(CallbackStorage[SerializedCompletionCallbacks]): + """ + A storage backend that stores callbacks inline with the message. + This is the default storage backend. + """ + + def store(self, callbacks: SerializedCompletionCallbacks) -> SerializedCompletionCallbacks: + return callbacks + + def retrieve(self, ref: SerializedCompletionCallbacks) -> SerializedCompletionCallbacks: + return ref + + +WorkflowRefT = TypeVar("WorkflowRefT") +CompletionCallbacksWithWorkflowRef = list[tuple[str, WorkflowRefT | None, bool]] + + +class _LazyLoadedWorkflow(Generic[WorkflowRefT]): + def __init__(self, ref: Any, load_func: LazyWorkflow): + self.ref = ref + self.load_func = load_func + + def __call__(self) -> dict: + return self.load_func() + + def __str__(self): + return f"_LazyLoadedWorkflow({self.ref})" + + +class DedupWorkflowCallbackStorage(CallbackStorage[CompletionCallbacksWithWorkflowRef], abc.ABC, Generic[WorkflowRefT]): + """ + An abstract storage backend that separates storage of workflows from + callbacks, allowing for deduplication of workflows. 
+ """ + + @abc.abstractmethod + def _store_workflow(self, id: str, workflow: dict) -> WorkflowRefT: + """ + Stores a workflow and returns a reference to it. The `id` can be used + to deduplicate workflows, and the `workflow` is the actual workflow to + store. The reference returned must be serializable by the broker's + encoder (e.g. JSON). + """ + raise NotImplementedError + + @abc.abstractmethod + def _load_workflow(self, id: str, ref: WorkflowRefT) -> dict: + """ + Loads a workflow using the deduplication ID and reference previously + returned by `store_workflow`. + """ + raise NotImplementedError + + def store(self, callbacks: SerializedCompletionCallbacks) -> CompletionCallbacksWithWorkflowRef[WorkflowRefT]: + """ + Stores callbacks, offloading workflow storage to `store_workflow`. + """ + new_callbacks: CompletionCallbacksWithWorkflowRef[WorkflowRefT] = [] + for completion_id, remaining_workflow, is_group in callbacks: + remaining_workflow_ref = None + if isinstance(remaining_workflow, _LazyLoadedWorkflow): + remaining_workflow_ref = cast(WorkflowRefT, remaining_workflow.ref) + elif isinstance(remaining_workflow, dict): + remaining_workflow_ref = self._store_workflow(completion_id, remaining_workflow) + elif remaining_workflow is not None: + raise TypeError( + "Unsupported workflow type: " + f"{type(remaining_workflow)}. Expected None, dict, or _LazyLoadedWorkflow." + ) + new_callbacks.append((completion_id, remaining_workflow_ref, is_group)) + + return new_callbacks + + def retrieve(self, ref: CompletionCallbacksWithWorkflowRef[WorkflowRefT]) -> SerializedCompletionCallbacks: + """ + Retrieves callbacks and prepares lazy loaders for workflows. + """ + new_callbacks: SerializedCompletionCallbacks = [] + for completion_id, workflow_ref, is_group in ref: + if workflow_ref is not None and not callable(workflow_ref): + workflow_ref = _LazyLoadedWorkflow( + ref=workflow_ref, + load_func=partial(self._load_workflow, completion_id, workflow_ref), + ) + new_callbacks.append((completion_id, workflow_ref, is_group)) + + return new_callbacks diff --git a/dramatiq_workflow/tests/test_middleware.py b/dramatiq_workflow/tests/test_middleware.py index 143f14f..6802757 100644 --- a/dramatiq_workflow/tests/test_middleware.py +++ b/dramatiq_workflow/tests/test_middleware.py @@ -1,4 +1,5 @@ import unittest +from typing import Any from unittest import mock import dramatiq @@ -8,7 +9,50 @@ from dramatiq_workflow import Chain, WorkflowMiddleware from dramatiq_workflow._barrier import AtMostOnceBarrier from dramatiq_workflow._constants import OPTION_KEY_CALLBACKS +from dramatiq_workflow._models import SerializedCompletionCallbacks from dramatiq_workflow._serialize import serialize_workflow +from dramatiq_workflow._storage import CallbackStorage + + +class MyLazyStorage(CallbackStorage): + def __init__(self): + self.workflows = {} + self.callbacks = {} + self.workflow_ref_counter = 0 + self.callback_ref_counter = 0 + self.retrieve_calls = [] + self.loaded_workflows = set() + + def _create_loader(self, ref: Any): + def loader() -> dict: + self.loaded_workflows.add(ref) + return self.retrieve_workflow(ref) + + return loader + + def store(self, callbacks: SerializedCompletionCallbacks) -> Any: + new_callbacks = [] + for completion_id, remaining_workflow, is_group in callbacks: + if isinstance(remaining_workflow, dict): + ref = self.workflow_ref_counter + self.workflows[ref] = remaining_workflow + self.workflow_ref_counter += 1 + lazy_workflow = self._create_loader(ref) + new_callbacks.append((completion_id, 
lazy_workflow, is_group)) + else: + new_callbacks.append((completion_id, remaining_workflow, is_group)) + + ref = self.callback_ref_counter + self.callbacks[ref] = new_callbacks + self.callback_ref_counter += 1 + return ref + + def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: + self.retrieve_calls.append(ref) + return self.callbacks[ref] + + def retrieve_workflow(self, ref: Any) -> dict: + return self.workflows[ref] class WorkflowMiddlewareTests(unittest.TestCase): @@ -71,7 +115,10 @@ def test_after_process_message_with_failed_message(self): @mock.patch("dramatiq_workflow._base.time.time") def test_after_process_message_with_workflow(self, mock_time): mock_time.return_value = 1337 - message = self._make_message({OPTION_KEY_CALLBACKS: [(None, self._create_serialized_workflow(), True)]}) + barrier_key = "barrier_1" + barrier = AtMostOnceBarrier(self.rate_limiter_backend, barrier_key) + barrier.create(1) + message = self._make_message({OPTION_KEY_CALLBACKS: [(barrier_key, self._create_serialized_workflow(), True)]}) self.middleware.after_process_message(self.broker, message) @@ -90,3 +137,38 @@ def test_after_process_message_with_barriered_workflow(self, mock_time): # Calling again, barrier should be completed now self.middleware.after_process_message(self.broker, message) self.broker.enqueue.assert_called_once_with(self._make_message(message_timestamp=1337_000)._message, delay=None) + + @mock.patch("dramatiq_workflow._base.time.time") + def test_after_process_message_with_lazy_loaded_workflow(self, mock_time): + mock_time.return_value = 1337 + storage = MyLazyStorage() + self.middleware = WorkflowMiddleware(self.rate_limiter_backend, callback_storage=storage) + + # Create a workflow that will be lazy loaded + serialized_workflow = self._create_serialized_workflow() + callbacks = [("barrier_1", serialized_workflow, True)] + + # Store it, which will convert it to a lazy workflow + callbacks_ref = storage.store(callbacks) + + # The lazy workflow object is now inside storage.callbacks[callbacks_ref] + lazy_workflow_obj = storage.callbacks[callbacks_ref][0][1] + self.assertTrue(callable(lazy_workflow_obj)) + + workflow_ref = 0 + self.assertNotIn(workflow_ref, storage.loaded_workflows) + + # Set up barrier + barrier = AtMostOnceBarrier(self.rate_limiter_backend, "barrier_1") + barrier.create(1) + + # Create message and process it + message = self._make_message({OPTION_KEY_CALLBACKS: callbacks_ref}) + self.middleware.after_process_message(self.broker, message) + + # Assertions + self.assertEqual(len(storage.retrieve_calls), 1) + self.assertEqual(storage.retrieve_calls[0], callbacks_ref) + self.assertIn(workflow_ref, storage.loaded_workflows) + + self.broker.enqueue.assert_called_once_with(self._make_message(message_timestamp=1337_000)._message, delay=None) diff --git a/dramatiq_workflow/tests/test_storage.py b/dramatiq_workflow/tests/test_storage.py new file mode 100644 index 0000000..13a2284 --- /dev/null +++ b/dramatiq_workflow/tests/test_storage.py @@ -0,0 +1,105 @@ +import unittest +from typing import Any + +from .._models import SerializedCompletionCallbacks +from .._storage import DedupWorkflowCallbackStorage + + +class MyDedupStorage(DedupWorkflowCallbackStorage): + def __init__(self): + self.workflows = {} + self.callbacks = {} + self.workflow_store_calls = [] + self.workflow_load_calls = [] + + def _store_workflow(self, id: str, workflow: dict) -> Any: + self.workflow_store_calls.append((id, workflow)) + ref = f"workflow-ref-{id}" + self.workflows[ref] = workflow + return ref 
+ + def _load_workflow(self, id: str, ref: Any) -> dict: + self.workflow_load_calls.append((id, ref)) + return self.workflows[ref] + + +class DedupWorkflowCallbackStorageTests(unittest.TestCase): + def setUp(self): + self.storage = MyDedupStorage() + + def test_store_and_retrieve(self): + workflow_dict = {"__type__": "chain", "children": []} + callbacks: SerializedCompletionCallbacks = [ + ("id1", workflow_dict, False), + ("id2", None, True), + ] + + # Store callbacks + callbacks_ref = self.storage.store(callbacks) + self.assertEqual(callbacks_ref, [("id1", "workflow-ref-id1", False), ("id2", None, True)]) + + # Check what was stored + self.assertEqual(len(self.storage.workflow_store_calls), 1) + self.assertEqual(self.storage.workflow_store_calls[0], ("id1", workflow_dict)) + + # Retrieve callbacks + retrieved_callbacks = self.storage.retrieve(callbacks_ref) + self.assertEqual(len(retrieved_callbacks), 2) + + # Check first callback (with workflow) + id1, loader1, is_group1 = retrieved_callbacks[0] + self.assertEqual(id1, "id1") + self.assertFalse(is_group1) + self.assertTrue(callable(loader1)) + + # Check second callback (without workflow) + id2, loader2, is_group2 = retrieved_callbacks[1] + self.assertEqual(id2, "id2") + self.assertTrue(is_group2) + self.assertIsNone(loader2) + + # Load the workflow + self.assertEqual(len(self.storage.workflow_load_calls), 0) + assert callable(loader1) + loaded_workflow = loader1() + self.assertEqual(len(self.storage.workflow_load_calls), 1) + self.assertEqual(self.storage.workflow_load_calls[0], ("id1", "workflow-ref-id1")) + self.assertEqual(loaded_workflow, workflow_dict) + + def test_store_with_unsupported_workflow_type(self): + def unsupported_lazy_workflow(): + return {"__type__": "chain", "children": []} + + callbacks: SerializedCompletionCallbacks = [ + ("id1", unsupported_lazy_workflow, False), + ] + + with self.assertRaises(TypeError): + self.storage.store(callbacks) + + def test_store_with_already_lazy_loaded_workflow(self): + # This test ensures that when we store a workflow that has already been + # loaded and wrapped by the storage, we don't try to store it again, + # but instead use its reference. + + # 1. Store a workflow and retrieve it. + workflow_dict = {"__type__": "chain", "children": []} + callbacks1: SerializedCompletionCallbacks = [("id1", workflow_dict, False)] + callbacks_ref1 = self.storage.store(callbacks1) + retrieved_callbacks1 = self.storage.retrieve(callbacks_ref1) + _, lazy_loader, _ = retrieved_callbacks1[0] + + # We should have one call to store the workflow + self.assertEqual(len(self.storage.workflow_store_calls), 1) + self.assertTrue(callable(lazy_loader)) + + # 2. Now, create new callbacks using the lazy loader from the previous step + # and store them. + callbacks2: SerializedCompletionCallbacks = [("id2", lazy_loader, False)] + callbacks_ref2 = self.storage.store(callbacks2) + + # 3. _store_workflow should NOT have been called again. + self.assertEqual(len(self.storage.workflow_store_calls), 1) + + # 4. The new stored callbacks should contain the original workflow reference. 
+ self.assertEqual(callbacks_ref2[0][1], callbacks_ref1[0][1]) diff --git a/dramatiq_workflow/tests/test_workflow.py b/dramatiq_workflow/tests/test_workflow.py index e82974b..8b20ee3 100644 --- a/dramatiq_workflow/tests/test_workflow.py +++ b/dramatiq_workflow/tests/test_workflow.py @@ -8,6 +8,16 @@ from .._serialize import serialize_workflow, unserialize_workflow +class MyLazyLoader: + def __init__(self, workflow: dict): + self._workflow = workflow + self.loaded = False + + def __call__(self) -> dict: + self.loaded = True + return self._workflow + + class WorkflowTests(unittest.TestCase): def setUp(self): self.rate_limiter_backend = mock.create_autospec(dramatiq.rate_limits.RateLimiterBackend, instance=True) @@ -359,6 +369,18 @@ def test_serialize_unserialize(self): unserialized = unserialize_workflow(serialized) self.assertEqual(workflow.workflow, unserialized) + def test_unserialize_lazy_workflow(self): + workflow = Chain(self.task.message()) + serialized = serialize_workflow(workflow) + self.assertIsNotNone(serialized) + + lazy_loader = MyLazyLoader(serialized) + self.assertFalse(lazy_loader.loaded) + + unserialized = unserialize_workflow(lazy_loader) + self.assertTrue(lazy_loader.loaded) + self.assertEqual(workflow, unserialized) + @mock.patch("dramatiq_workflow._base.time.time") def test_additive_delays(self, time_mock): time_mock.return_value = 1717526000.12 @@ -440,7 +462,7 @@ def test_nested_delays(self, time_mock): ) def test_middleware_is_cached(self): - workflow = Workflow(Chain(self.task.message(), self.task.message()), broker=self.broker) + workflow = Workflow(Chain(), broker=self.broker) # Access middleware property multiple times workflow.run()
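
The README's `MyDedupStorage` example above keeps workflow definitions in an in-process dict, which only works while producers and workers share a process. A persistent variant might look roughly like the sketch below. It assumes the third-party `redis` client library is available; the `RedisDedupStorage` name, key prefix, and TTL are illustrative and not part of this change.

```python
import json
from typing import Any

import redis  # assumed third-party dependency, not required by dramatiq-workflow

from dramatiq_workflow import DedupWorkflowCallbackStorage


class RedisDedupStorage(DedupWorkflowCallbackStorage):
    """Illustrative sketch: stores deduplicated workflow definitions in Redis."""

    def __init__(self, client: redis.Redis | None = None, ttl_seconds: int = 86_400):
        self._redis = client or redis.Redis()
        self._ttl = ttl_seconds

    def _store_workflow(self, id: str, workflow: dict) -> Any:
        # The completion ID doubles as the deduplication key: NX keeps the first
        # copy written for a given ID, EX expires it once the workflow is done.
        key = f"dramatiq-workflow:{id}"
        self._redis.set(key, json.dumps(workflow), nx=True, ex=self._ttl)
        return key  # Reference stored in the message options; must be JSON-serializable.

    def _load_workflow(self, id: str, ref: Any) -> dict:
        raw = self._redis.get(ref)
        if raw is None:
            raise KeyError(f"Workflow {ref!r} not found (expired or never stored)")
        return json.loads(raw)
```

Wiring it up mirrors the README snippet above: `broker.add_middleware(WorkflowMiddleware(backend, callback_storage=RedisDedupStorage()))`. The TTL should comfortably exceed the longest expected workflow runtime, since an expired entry would make `_load_workflow` fail when a later task finishes.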