From da81513fb27faaa090f30c0a2c6a916967264eb6 Mon Sep 17 00:00:00 2001 From: Nils Caspar Date: Tue, 17 Jun 2025 09:26:06 -0700 Subject: [PATCH 01/10] Allow custom callback store --- README.md | 73 +++++++++++++++++++-- dramatiq_workflow/__init__.py | 15 ++++- dramatiq_workflow/_base.py | 8 ++- dramatiq_workflow/_middleware.py | 19 +++--- dramatiq_workflow/_models.py | 2 +- dramatiq_workflow/_storage.py | 74 ++++++++++++++++++++++ dramatiq_workflow/tests/test_middleware.py | 47 +++++++++++++- dramatiq_workflow/tests/test_storage.py | 34 ++++++++++ dramatiq_workflow/tests/test_workflow.py | 70 ++++++++++++++++++++ 9 files changed, 326 insertions(+), 16 deletions(-) create mode 100644 dramatiq_workflow/_storage.py create mode 100644 dramatiq_workflow/tests/test_storage.py diff --git a/README.md b/README.md index f493f63..0ba0aab 100644 --- a/README.md +++ b/README.md @@ -157,7 +157,8 @@ Task 3 and will run 2 seconds after Task 2 finishes. Because of how `dramatiq-workflow` is implemented, each task in a workflow has to know about the remaining tasks in the workflow that could potentially run -after it. When a workflow has a large number of tasks, this can lead to an +after it. By default, this is stored alongside your messages in the message +queue. When a workflow has a large number of tasks, it can lead to an increase of memory usage in the broker and increased network traffic between the broker and the workers, especially when using `Group` tasks: Each task in a `Group` can potentially be the last one to finish, so each task has to retain a @@ -173,9 +174,11 @@ There are a few things you can do to alleviate this issue: - Consider breaking down large workflows into smaller partial workflows that then schedule a subsequent workflow at the very end of the outermost `Chain`. -Lastly, you can use compression to reduce the size of the messages in your -queue. While dramatiq does not provide a compression implementation by default, -one can be added with just a few lines of code. For example: +#### Compression + +You can use compression to reduce the size of the messages in your queue. While +dramatiq does not provide a compression implementation by default, one can be +added with just a few lines of code. For example: ```python import dramatiq @@ -197,6 +200,68 @@ class DramatiqLz4JSONEncoder(JSONEncoder): dramatiq.set_encoder(DramatiqLz4JSONEncoder()) ``` +#### Callback Storage + +To completely eliminate the issue of large workflows being stored in your +message queue, you can provide a custom callback storage backend to the +`WorkflowMiddleware`. A callback storage backend is responsible for storing and +retrieving the list of callbacks. For example, you could implement a storage +backend that stores the callbacks in S3 and only stores a reference to the S3 +object in the message options. + +A storage backend must implement the `CallbackStorage` interface: + +```python +from typing import Any +from dramatiq_workflow import CallbackStorage, SerializedCompletionCallbacks + +class MyS3Storage(CallbackStorage): + def store(self, callbacks: SerializedCompletionCallbacks) -> Any: + # ... store in S3 and return a key + pass + + def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: + # ... retrieve from S3 using the key + pass +``` + +Then, you can pass an instance of your custom storage backend to the +`WorkflowMiddleware`: + +```python +from dramatiq.rate_limits.backends import RedisBackend +from dramatiq_workflow import WorkflowMiddleware + +backend = RedisBackend() +storage = MyS3Storage() # Your custom storage backend +broker.add_middleware(WorkflowMiddleware(backend, callback_storage=storage)) +``` + +#### Deduplicating Callbacks + +For convenience, `CallbackStorage` provides a helper method `_determine_dedup_key` +that you can use to deduplicate callbacks for `Group` tasks. If the +deduplication key already exists, storing can be skipped. + +```python +from typing import Any +from dramatiq_workflow import CallbackStorage, SerializedCompletionCallbacks + +class MyDedupStorage(CallbackStorage): + def __init__(self): + self.__storage = MagicStorage() + + def store(self, callbacks: SerializedCompletionCallbacks) -> str: + dedup_key, _ = self._determine_dedup_key(callbacks) + if not self.__storage.exists(dedup_key): + self.__storage.set(dedup_key, callbacks) + return dedup_key + + def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: + # ref will be the deduplication key from before + return self.__storage.get(ref) +``` + ### Barrier `dramatiq-workflow` uses a barrier mechanism to keep track of the current state diff --git a/dramatiq_workflow/__init__.py b/dramatiq_workflow/__init__.py index 3f003c9..925c9bc 100644 --- a/dramatiq_workflow/__init__.py +++ b/dramatiq_workflow/__init__.py @@ -1,11 +1,24 @@ from ._base import Workflow from ._middleware import WorkflowMiddleware -from ._models import Chain, Group, Message, WithDelay, WorkflowType +from ._models import ( + Chain, + Group, + Message, + SerializedCompletionCallback, + SerializedCompletionCallbacks, + WithDelay, + WorkflowType, +) +from ._storage import CallbackStorage, InlineCallbackStorage __all__ = [ + "CallbackStorage", "Chain", "Group", + "InlineCallbackStorage", "Message", + "SerializedCompletionCallback", + "SerializedCompletionCallbacks", "WithDelay", "Workflow", "WorkflowMiddleware", diff --git a/dramatiq_workflow/_base.py b/dramatiq_workflow/_base.py index c5d7e2c..a14ecd6 100644 --- a/dramatiq_workflow/_base.py +++ b/dramatiq_workflow/_base.py @@ -10,6 +10,7 @@ from ._middleware import WorkflowMiddleware, workflow_noop from ._models import Chain, Group, Message, SerializedCompletionCallbacks, WithDelay, WorkflowType from ._serialize import serialize_workflow +from ._storage import CallbackStorage logger = logging.getLogger(__name__) @@ -152,7 +153,8 @@ def __schedule_noop(self, completion_callbacks: SerializedCompletionCallbacks): def __augment_message(self, message: Message, completion_callbacks: SerializedCompletionCallbacks) -> Message: options = {} if completion_callbacks: - options = {OPTION_KEY_CALLBACKS: completion_callbacks} + callbacks_ref = self.__callback_storage.store(completion_callbacks) + options = {OPTION_KEY_CALLBACKS: callbacks_ref} return message.copy( # We reset the message timestamp to better represent the time the @@ -185,6 +187,10 @@ def __rate_limiter_backend(self) -> dramatiq.rate_limits.RateLimiterBackend: def __barrier_type(self) -> type[dramatiq.rate_limits.Barrier]: return self.__middleware.barrier_type + @property + def __callback_storage(self) -> CallbackStorage: + return self.__middleware.callback_storage + def __create_barrier(self, count: int) -> str: completion_uuid = str(uuid4()) completion_barrier = self.__barrier_type(self.__rate_limiter_backend, completion_uuid, ttl=CALLBACK_BARRIER_TTL) diff --git a/dramatiq_workflow/_middleware.py b/dramatiq_workflow/_middleware.py index f751610..a4bafe4 100644 --- a/dramatiq_workflow/_middleware.py +++ b/dramatiq_workflow/_middleware.py @@ -6,8 +6,8 @@ from ._barrier import AtMostOnceBarrier from ._constants import OPTION_KEY_CALLBACKS from ._helpers import workflow_with_completion_callbacks -from ._models import SerializedCompletionCallbacks from ._serialize import unserialize_workflow +from ._storage import CallbackStorage, InlineCallbackStorage logger = logging.getLogger(__name__) @@ -17,9 +17,11 @@ def __init__( self, rate_limiter_backend: dramatiq.rate_limits.RateLimiterBackend, barrier_type: type[dramatiq.rate_limits.Barrier] = AtMostOnceBarrier, + callback_storage: CallbackStorage | None = None, ): self.rate_limiter_backend = rate_limiter_backend self.barrier_type = barrier_type + self.callback_storage = callback_storage or InlineCallbackStorage() def after_process_boot(self, broker: dramatiq.Broker): broker.declare_actor(workflow_noop) @@ -34,18 +36,19 @@ def after_process_message( if message.failed: return - completion_callbacks: SerializedCompletionCallbacks | None = message.options.get(OPTION_KEY_CALLBACKS) - if completion_callbacks is None: + callbacks_ref = message.options.get(OPTION_KEY_CALLBACKS) + if callbacks_ref is None: return + completion_callbacks = self.callback_storage.retrieve(callbacks_ref) + # Go through the completion callbacks backwards until we hit the first non-completed barrier while len(completion_callbacks) > 0: completion_id, remaining_workflow, propagate = completion_callbacks[-1] - if completion_id is not None: - barrier = self.barrier_type(self.rate_limiter_backend, completion_id) - if not barrier.wait(block=False): - logger.debug("Barrier not completed: %s", completion_id) - break + barrier = self.barrier_type(self.rate_limiter_backend, completion_id) + if not barrier.wait(block=False): + logger.debug("Barrier not completed: %s", completion_id) + break logger.debug("Barrier completed: %s", completion_id) completion_callbacks.pop() diff --git a/dramatiq_workflow/_models.py b/dramatiq_workflow/_models.py index 68b7cd9..59a0be1 100644 --- a/dramatiq_workflow/_models.py +++ b/dramatiq_workflow/_models.py @@ -39,5 +39,5 @@ def __eq__(self, other): Message = dramatiq.Message WorkflowType = Message | Chain | Group | WithDelay -SerializedCompletionCallback = tuple[str | None, dict | None, bool] +SerializedCompletionCallback = tuple[str, dict | None, bool] SerializedCompletionCallbacks = list[SerializedCompletionCallback] diff --git a/dramatiq_workflow/_storage.py b/dramatiq_workflow/_storage.py new file mode 100644 index 0000000..b12d24f --- /dev/null +++ b/dramatiq_workflow/_storage.py @@ -0,0 +1,74 @@ +import abc +from typing import Any + +from ._models import SerializedCompletionCallbacks + + +class CallbackStorage(abc.ABC): + """ + Abstract base class for callback storage backends. + """ + + @abc.abstractmethod + def store(self, callbacks: SerializedCompletionCallbacks) -> Any: + """ + Stores callbacks and returns a reference to them. + + This reference will be stored in the dramatiq message options. It must + be serializable by the broker's encoder (e.g. JSON). + + Args: + callbacks: The callbacks to store. + + Returns: + A serializable reference to the stored callbacks. + """ + raise NotImplementedError + + @abc.abstractmethod + def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: + """ + Retrieves callbacks using a reference. + + Args: + ref: The reference to the callbacks, as returned by `store`. + + Returns: + The retrieved callbacks. + """ + raise NotImplementedError + + def _determine_dedup_key(self, callbacks: SerializedCompletionCallbacks) -> tuple[str, bool]: + """ + Determines a deduplication key for the given callbacks. + + This is used by deduplication storage backends to identify unique + callback sets. + + Returns: + A tuple containing the completion ID and a boolean indicating if + the callbacks are part of a group (i.e. if deduplication is + strictly needed). + """ + + # NOTE: `Workflow.__augment_message` only calls the `CallbackStorage` + # when the `callbacks` list is not empty. This `assert` should always + # hold true. + assert isinstance(callbacks, list) and len(callbacks) > 0, "Callbacks must be a non-empty list" + + last_callback = callbacks[-1] + completion_id, _, is_group = last_callback + return completion_id, is_group + + +class InlineCallbackStorage(CallbackStorage): + """ + A storage backend that stores callbacks inline with the message. + This is the default storage backend. + """ + + def store(self, callbacks: SerializedCompletionCallbacks) -> SerializedCompletionCallbacks: + return callbacks + + def retrieve(self, ref: SerializedCompletionCallbacks) -> SerializedCompletionCallbacks: + return ref diff --git a/dramatiq_workflow/tests/test_middleware.py b/dramatiq_workflow/tests/test_middleware.py index 143f14f..ca42781 100644 --- a/dramatiq_workflow/tests/test_middleware.py +++ b/dramatiq_workflow/tests/test_middleware.py @@ -1,4 +1,5 @@ import unittest +from typing import Any from unittest import mock import dramatiq @@ -8,7 +9,25 @@ from dramatiq_workflow import Chain, WorkflowMiddleware from dramatiq_workflow._barrier import AtMostOnceBarrier from dramatiq_workflow._constants import OPTION_KEY_CALLBACKS +from dramatiq_workflow._models import SerializedCompletionCallbacks from dramatiq_workflow._serialize import serialize_workflow +from dramatiq_workflow._storage import CallbackStorage + + +class MyDedupStorage(CallbackStorage): + def __init__(self): + self.storage = {} + self.retrieve_calls = [] + + def store(self, callbacks: SerializedCompletionCallbacks) -> Any: + dedup_key, _ = self._determine_dedup_key(callbacks) + if dedup_key not in self.storage: + self.storage[dedup_key] = callbacks + return dedup_key + + def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: + self.retrieve_calls.append(ref) + return self.storage[ref] class WorkflowMiddlewareTests(unittest.TestCase): @@ -71,7 +90,10 @@ def test_after_process_message_with_failed_message(self): @mock.patch("dramatiq_workflow._base.time.time") def test_after_process_message_with_workflow(self, mock_time): mock_time.return_value = 1337 - message = self._make_message({OPTION_KEY_CALLBACKS: [(None, self._create_serialized_workflow(), True)]}) + barrier_key = "barrier_1" + barrier = AtMostOnceBarrier(self.rate_limiter_backend, barrier_key) + barrier.create(1) + message = self._make_message({OPTION_KEY_CALLBACKS: [(barrier_key, self._create_serialized_workflow(), True)]}) self.middleware.after_process_message(self.broker, message) @@ -90,3 +112,26 @@ def test_after_process_message_with_barriered_workflow(self, mock_time): # Calling again, barrier should be completed now self.middleware.after_process_message(self.broker, message) self.broker.enqueue.assert_called_once_with(self._make_message(message_timestamp=1337_000)._message, delay=None) + + @mock.patch("dramatiq_workflow._base.time.time") + def test_after_process_message_with_custom_storage(self, mock_time): + mock_time.return_value = 1337 + storage = MyDedupStorage() + self.middleware = WorkflowMiddleware(self.rate_limiter_backend, callback_storage=storage) + + serialized_workflow = self._create_serialized_workflow() + callbacks = [("barrier_1", serialized_workflow, True)] + + barrier = AtMostOnceBarrier(self.rate_limiter_backend, "barrier_1") + barrier.create(1) + + dedup_key = storage.store(callbacks) + + message = self._make_message({OPTION_KEY_CALLBACKS: dedup_key}) + + self.middleware.after_process_message(self.broker, message) + + self.assertEqual(len(storage.retrieve_calls), 1) + self.assertEqual(storage.retrieve_calls[0], dedup_key) + + self.broker.enqueue.assert_called_once_with(self._make_message(message_timestamp=1337_000)._message, delay=None) diff --git a/dramatiq_workflow/tests/test_storage.py b/dramatiq_workflow/tests/test_storage.py new file mode 100644 index 0000000..3ad06c9 --- /dev/null +++ b/dramatiq_workflow/tests/test_storage.py @@ -0,0 +1,34 @@ +import unittest + +from .._models import SerializedCompletionCallbacks +from .._storage import InlineCallbackStorage + + +class CallbackStorageTests(unittest.TestCase): + def setUp(self): + self.storage = InlineCallbackStorage() + + def test_determine_dedup_key_for_group(self): + callbacks: SerializedCompletionCallbacks = [("group-id-123", None, True)] + dedup_key, is_group = self.storage._determine_dedup_key(callbacks) + self.assertEqual(dedup_key, "group-id-123") + self.assertTrue(is_group) + + def test_determine_dedup_key_for_chain(self): + callbacks: SerializedCompletionCallbacks = [("chain-id-456", {"__type__": "chain"}, False)] + dedup_key, is_group = self.storage._determine_dedup_key(callbacks) + self.assertEqual(dedup_key, "chain-id-456") + self.assertFalse(is_group) + + def test_determine_dedup_key_for_nested_callbacks(self): + callbacks: SerializedCompletionCallbacks = [ + ("chain-id-456", {"__type__": "chain"}, False), + ("group-id-123", None, True), + ] + dedup_key, is_group = self.storage._determine_dedup_key(callbacks) + self.assertEqual(dedup_key, "group-id-123") + self.assertTrue(is_group) + + def test_determine_dedup_key_with_empty_list_raises_assertion_error(self): + with self.assertRaises(AssertionError): + self.storage._determine_dedup_key([]) diff --git a/dramatiq_workflow/tests/test_workflow.py b/dramatiq_workflow/tests/test_workflow.py index 3f8f71d..099cfe8 100644 --- a/dramatiq_workflow/tests/test_workflow.py +++ b/dramatiq_workflow/tests/test_workflow.py @@ -1,11 +1,31 @@ import unittest +from typing import Any from unittest import mock import dramatiq import dramatiq.rate_limits from .. import Chain, Group, WithDelay, Workflow, WorkflowMiddleware +from .._constants import OPTION_KEY_CALLBACKS +from .._models import SerializedCompletionCallbacks from .._serialize import serialize_workflow, unserialize_workflow +from .._storage import CallbackStorage + + +class MyDedupStorage(CallbackStorage): + def __init__(self): + self.storage = {} + self.store_calls = [] + + def store(self, callbacks: SerializedCompletionCallbacks) -> Any: + dedup_key, is_group = self._determine_dedup_key(callbacks) + self.store_calls.append((dedup_key, callbacks, is_group)) + if dedup_key not in self.storage: + self.storage[dedup_key] = callbacks + return dedup_key + + def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: + return self.storage[ref] class WorkflowTests(unittest.TestCase): @@ -389,3 +409,53 @@ def test_nested_delays(self, time_mock): ), delay=20, ) + + @mock.patch("dramatiq_workflow._base.time.time") + def test_workflow_with_custom_storage(self, time_mock): + time_mock.return_value = 1717526000.12 + + storage = MyDedupStorage() + # The broker is a mock object, we can just replace the middleware list + self.broker.middleware = [ + WorkflowMiddleware( + rate_limiter_backend=self.rate_limiter_backend, + barrier_type=self.barrier, + callback_storage=storage, + ) + ] + + workflow = Workflow(Group(self.task.message(), self.task.message()), broker=self.broker) + workflow.run() + + # Assertions + self.assertEqual(len(storage.store_calls), 2) + + dedup_key1, callbacks1, is_group1 = storage.store_calls[0] + dedup_key2, callbacks2, is_group2 = storage.store_calls[1] + self.assertEqual(dedup_key1, dedup_key2) + self.assertEqual(callbacks1, callbacks2) + self.assertTrue(is_group1) + self.assertTrue(is_group2) + + self.assertEqual(len(storage.storage), 1) + self.assertIn(dedup_key1, storage.storage) + + self.broker.enqueue.assert_has_calls( + [ + mock.call(mock.ANY, delay=None), + mock.call(mock.ANY, delay=None), + ], + any_order=True, + ) + + # Check the options passed to enqueue + self.assertEqual(len(self.broker.enqueue.call_args_list), 2) + message1 = self.broker.enqueue.call_args_list[0][0][0] + delay1 = self.broker.enqueue.call_args_list[0][1]["delay"] + message2 = self.broker.enqueue.call_args_list[1][0][0] + delay2 = self.broker.enqueue.call_args_list[1][1]["delay"] + + self.assertEqual(message1.options[OPTION_KEY_CALLBACKS], dedup_key1) + self.assertEqual(delay1, None) + self.assertEqual(message2.options[OPTION_KEY_CALLBACKS], dedup_key2) + self.assertEqual(delay2, None) From e63498c449930c010ee2607637d8028f1cf39469 Mon Sep 17 00:00:00 2001 From: Nils Caspar Date: Tue, 17 Jun 2025 10:31:29 -0700 Subject: [PATCH 02/10] Unrelated: Add some logging when barrier is released too many times --- dramatiq_workflow/_barrier.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dramatiq_workflow/_barrier.py b/dramatiq_workflow/_barrier.py index a1377c3..b7ca619 100644 --- a/dramatiq_workflow/_barrier.py +++ b/dramatiq_workflow/_barrier.py @@ -1,5 +1,9 @@ +import logging + import dramatiq.rate_limits +logger = logging.getLogger(__name__) + class AtMostOnceBarrier(dramatiq.rate_limits.Barrier): """ @@ -33,6 +37,8 @@ def wait(self, *args, block=True, timeout=None): released = super().wait(*args, block=False) if released: never_released = self.backend.incr(self.ran_key, 1, 0, self.ttl) + if not never_released: + logger.warning("Barrier %s release already recorded; ignoring subsequent release attempt", self.key) return never_released return False From 9a7135106b8f9e126ab9533942588052b78c0a9a Mon Sep 17 00:00:00 2001 From: Nils Caspar Date: Tue, 17 Jun 2025 19:16:37 -0700 Subject: [PATCH 03/10] Allow lazy loaded workflows to be returned by storage --- dramatiq_workflow/__init__.py | 2 + dramatiq_workflow/_models.py | 13 +++- dramatiq_workflow/_serialize.py | 28 ++++++-- dramatiq_workflow/tests/test_middleware.py | 79 +++++++++++++++++++++- dramatiq_workflow/tests/test_workflow.py | 24 ++++++- 5 files changed, 136 insertions(+), 10 deletions(-) diff --git a/dramatiq_workflow/__init__.py b/dramatiq_workflow/__init__.py index 925c9bc..613689a 100644 --- a/dramatiq_workflow/__init__.py +++ b/dramatiq_workflow/__init__.py @@ -3,6 +3,7 @@ from ._models import ( Chain, Group, + LazyWorkflow, Message, SerializedCompletionCallback, SerializedCompletionCallbacks, @@ -16,6 +17,7 @@ "Chain", "Group", "InlineCallbackStorage", + "LazyWorkflow", "Message", "SerializedCompletionCallback", "SerializedCompletionCallbacks", diff --git a/dramatiq_workflow/_models.py b/dramatiq_workflow/_models.py index 59a0be1..cedc945 100644 --- a/dramatiq_workflow/_models.py +++ b/dramatiq_workflow/_models.py @@ -1,7 +1,18 @@ +import abc + import dramatiq import dramatiq.rate_limits +class LazyWorkflow(abc.ABC): + """Abstract base class for lazily-loaded workflows.""" + + @abc.abstractmethod + def load(self) -> dict: + """Loads the workflow representation.""" + raise NotImplementedError + + class Chain: def __init__(self, *tasks: "WorkflowType"): self.tasks = list(tasks) @@ -39,5 +50,5 @@ def __eq__(self, other): Message = dramatiq.Message WorkflowType = Message | Chain | Group | WithDelay -SerializedCompletionCallback = tuple[str, dict | None, bool] +SerializedCompletionCallback = tuple[str, dict | LazyWorkflow | None, bool] SerializedCompletionCallbacks = list[SerializedCompletionCallback] diff --git a/dramatiq_workflow/_serialize.py b/dramatiq_workflow/_serialize.py index 2d0ad89..f03da70 100644 --- a/dramatiq_workflow/_serialize.py +++ b/dramatiq_workflow/_serialize.py @@ -3,6 +3,7 @@ from ._models import ( Chain, Group, + LazyWorkflow, Message, WithDelay, WorkflowType, @@ -48,12 +49,31 @@ def unserialize_workflow(workflow: typing.Any) -> WorkflowType: Return an unserialized version of the workflow that can be used to create a Workflow instance. """ + result = unserialize_workflow_or_none(workflow) + if result is None: + raise ValueError("Cannot unserialize a workflow that resolves to None") + return result + + +def unserialize_workflow_or_none(workflow: typing.Any) -> WorkflowType | None: + """ + Return an unserialized version of the workflow that can be used to create + a Workflow instance. + """ + if isinstance(workflow, LazyWorkflow): + workflow = workflow.load() + if workflow is None: - raise ValueError("Cannot unserialize None") + return None if not isinstance(workflow, dict): raise TypeError(f"Unsupported data type: {type(workflow)}") + return _unserialize_workflow_from_dict(workflow) + + +def _unserialize_workflow_from_dict(workflow: dict) -> WorkflowType: + """Helper to unserialize from a dict, assuming it's not None and is a dict.""" workflow_type = workflow.pop("__type__") match workflow_type: case "message": @@ -69,9 +89,3 @@ def unserialize_workflow(workflow: typing.Any) -> WorkflowType: ) raise TypeError(f"Unsupported workflow type: {workflow_type}") - - -def unserialize_workflow_or_none(workflow: typing.Any) -> WorkflowType | None: - if workflow is None: - return None - return unserialize_workflow(workflow) diff --git a/dramatiq_workflow/tests/test_middleware.py b/dramatiq_workflow/tests/test_middleware.py index ca42781..e33f3bf 100644 --- a/dramatiq_workflow/tests/test_middleware.py +++ b/dramatiq_workflow/tests/test_middleware.py @@ -9,7 +9,7 @@ from dramatiq_workflow import Chain, WorkflowMiddleware from dramatiq_workflow._barrier import AtMostOnceBarrier from dramatiq_workflow._constants import OPTION_KEY_CALLBACKS -from dramatiq_workflow._models import SerializedCompletionCallbacks +from dramatiq_workflow._models import LazyWorkflow, SerializedCompletionCallbacks from dramatiq_workflow._serialize import serialize_workflow from dramatiq_workflow._storage import CallbackStorage @@ -30,6 +30,50 @@ def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: return self.storage[ref] +class MyLazyWorkflow(LazyWorkflow): + def __init__(self, storage: "MyLazyStorage", ref: Any): + self._storage = storage + self._ref = ref + self.loaded = False + + def load(self) -> dict: + self.loaded = True + return self._storage.retrieve_workflow(self._ref) + + +class MyLazyStorage(CallbackStorage): + def __init__(self): + self.workflows = {} + self.callbacks = {} + self.workflow_ref_counter = 0 + self.callback_ref_counter = 0 + self.retrieve_calls = [] + + def store(self, callbacks: SerializedCompletionCallbacks) -> Any: + new_callbacks = [] + for completion_id, remaining_workflow, is_group in callbacks: + if isinstance(remaining_workflow, dict): + ref = self.workflow_ref_counter + self.workflows[ref] = remaining_workflow + self.workflow_ref_counter += 1 + lazy_workflow = MyLazyWorkflow(self, ref) + new_callbacks.append((completion_id, lazy_workflow, is_group)) + else: + new_callbacks.append((completion_id, remaining_workflow, is_group)) + + ref = self.callback_ref_counter + self.callbacks[ref] = new_callbacks + self.callback_ref_counter += 1 + return ref + + def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: + self.retrieve_calls.append(ref) + return self.callbacks[ref] + + def retrieve_workflow(self, ref: Any) -> dict: + return self.workflows[ref] + + class WorkflowMiddlewareTests(unittest.TestCase): def setUp(self): # Initialize common mocks and the middleware instance for each test @@ -135,3 +179,36 @@ def test_after_process_message_with_custom_storage(self, mock_time): self.assertEqual(storage.retrieve_calls[0], dedup_key) self.broker.enqueue.assert_called_once_with(self._make_message(message_timestamp=1337_000)._message, delay=None) + + @mock.patch("dramatiq_workflow._base.time.time") + def test_after_process_message_with_lazy_loaded_workflow(self, mock_time): + mock_time.return_value = 1337 + storage = MyLazyStorage() + self.middleware = WorkflowMiddleware(self.rate_limiter_backend, callback_storage=storage) + + # Create a workflow that will be lazy loaded + serialized_workflow = self._create_serialized_workflow() + callbacks = [("barrier_1", serialized_workflow, True)] + + # Store it, which will convert it to a lazy workflow + callbacks_ref = storage.store(callbacks) + + # The lazy workflow object is now inside storage.callbacks[callbacks_ref] + lazy_workflow_obj = storage.callbacks[callbacks_ref][0][1] + self.assertIsInstance(lazy_workflow_obj, MyLazyWorkflow) + self.assertFalse(lazy_workflow_obj.loaded) + + # Set up barrier + barrier = AtMostOnceBarrier(self.rate_limiter_backend, "barrier_1") + barrier.create(1) + + # Create message and process it + message = self._make_message({OPTION_KEY_CALLBACKS: callbacks_ref}) + self.middleware.after_process_message(self.broker, message) + + # Assertions + self.assertEqual(len(storage.retrieve_calls), 1) + self.assertEqual(storage.retrieve_calls[0], callbacks_ref) + self.assertTrue(lazy_workflow_obj.loaded) + + self.broker.enqueue.assert_called_once_with(self._make_message(message_timestamp=1337_000)._message, delay=None) diff --git a/dramatiq_workflow/tests/test_workflow.py b/dramatiq_workflow/tests/test_workflow.py index 099cfe8..cd0fab8 100644 --- a/dramatiq_workflow/tests/test_workflow.py +++ b/dramatiq_workflow/tests/test_workflow.py @@ -7,7 +7,7 @@ from .. import Chain, Group, WithDelay, Workflow, WorkflowMiddleware from .._constants import OPTION_KEY_CALLBACKS -from .._models import SerializedCompletionCallbacks +from .._models import LazyWorkflow, SerializedCompletionCallbacks from .._serialize import serialize_workflow, unserialize_workflow from .._storage import CallbackStorage @@ -28,6 +28,16 @@ def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: return self.storage[ref] +class MyLazyWorkflow(LazyWorkflow): + def __init__(self, workflow: dict): + self._workflow = workflow + self.loaded = False + + def load(self) -> dict: + self.loaded = True + return self._workflow + + class WorkflowTests(unittest.TestCase): def setUp(self): self.rate_limiter_backend = mock.create_autospec(dramatiq.rate_limits.RateLimiterBackend, instance=True) @@ -330,6 +340,18 @@ def test_serialize_unserialize(self): unserialized = unserialize_workflow(serialized) self.assertEqual(workflow.workflow, unserialized) + def test_unserialize_lazy_workflow(self): + workflow = Chain(self.task.message()) + serialized = serialize_workflow(workflow) + self.assertIsNotNone(serialized) + + lazy_workflow = MyLazyWorkflow(serialized) + self.assertFalse(lazy_workflow.loaded) + + unserialized = unserialize_workflow(lazy_workflow) + self.assertTrue(lazy_workflow.loaded) + self.assertEqual(workflow, unserialized) + @mock.patch("dramatiq_workflow._base.time.time") def test_additive_delays(self, time_mock): time_mock.return_value = 1717526000.12 From 756ff661e0c9b57456db6b1e0bc39254ccdae3de Mon Sep 17 00:00:00 2001 From: Nils Caspar Date: Wed, 18 Jun 2025 09:00:19 -0700 Subject: [PATCH 04/10] Simplify lazy workflow interface --- dramatiq_workflow/_models.py | 12 ++------- dramatiq_workflow/_serialize.py | 5 ++-- dramatiq_workflow/tests/test_middleware.py | 31 +++++++++++----------- dramatiq_workflow/tests/test_workflow.py | 14 +++++----- 4 files changed, 26 insertions(+), 36 deletions(-) diff --git a/dramatiq_workflow/_models.py b/dramatiq_workflow/_models.py index cedc945..03c26a1 100644 --- a/dramatiq_workflow/_models.py +++ b/dramatiq_workflow/_models.py @@ -1,18 +1,9 @@ -import abc +import typing import dramatiq import dramatiq.rate_limits -class LazyWorkflow(abc.ABC): - """Abstract base class for lazily-loaded workflows.""" - - @abc.abstractmethod - def load(self) -> dict: - """Loads the workflow representation.""" - raise NotImplementedError - - class Chain: def __init__(self, *tasks: "WorkflowType"): self.tasks = list(tasks) @@ -50,5 +41,6 @@ def __eq__(self, other): Message = dramatiq.Message WorkflowType = Message | Chain | Group | WithDelay +LazyWorkflow = typing.Callable[[], dict] SerializedCompletionCallback = tuple[str, dict | LazyWorkflow | None, bool] SerializedCompletionCallbacks = list[SerializedCompletionCallback] diff --git a/dramatiq_workflow/_serialize.py b/dramatiq_workflow/_serialize.py index f03da70..60caec9 100644 --- a/dramatiq_workflow/_serialize.py +++ b/dramatiq_workflow/_serialize.py @@ -3,7 +3,6 @@ from ._models import ( Chain, Group, - LazyWorkflow, Message, WithDelay, WorkflowType, @@ -60,8 +59,8 @@ def unserialize_workflow_or_none(workflow: typing.Any) -> WorkflowType | None: Return an unserialized version of the workflow that can be used to create a Workflow instance. """ - if isinstance(workflow, LazyWorkflow): - workflow = workflow.load() + if callable(workflow): + workflow = workflow() if workflow is None: return None diff --git a/dramatiq_workflow/tests/test_middleware.py b/dramatiq_workflow/tests/test_middleware.py index e33f3bf..0484705 100644 --- a/dramatiq_workflow/tests/test_middleware.py +++ b/dramatiq_workflow/tests/test_middleware.py @@ -9,7 +9,7 @@ from dramatiq_workflow import Chain, WorkflowMiddleware from dramatiq_workflow._barrier import AtMostOnceBarrier from dramatiq_workflow._constants import OPTION_KEY_CALLBACKS -from dramatiq_workflow._models import LazyWorkflow, SerializedCompletionCallbacks +from dramatiq_workflow._models import SerializedCompletionCallbacks from dramatiq_workflow._serialize import serialize_workflow from dramatiq_workflow._storage import CallbackStorage @@ -30,17 +30,6 @@ def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: return self.storage[ref] -class MyLazyWorkflow(LazyWorkflow): - def __init__(self, storage: "MyLazyStorage", ref: Any): - self._storage = storage - self._ref = ref - self.loaded = False - - def load(self) -> dict: - self.loaded = True - return self._storage.retrieve_workflow(self._ref) - - class MyLazyStorage(CallbackStorage): def __init__(self): self.workflows = {} @@ -48,6 +37,14 @@ def __init__(self): self.workflow_ref_counter = 0 self.callback_ref_counter = 0 self.retrieve_calls = [] + self.loaded_workflows = set() + + def _create_loader(self, ref: Any): + def loader() -> dict: + self.loaded_workflows.add(ref) + return self.retrieve_workflow(ref) + + return loader def store(self, callbacks: SerializedCompletionCallbacks) -> Any: new_callbacks = [] @@ -56,7 +53,7 @@ def store(self, callbacks: SerializedCompletionCallbacks) -> Any: ref = self.workflow_ref_counter self.workflows[ref] = remaining_workflow self.workflow_ref_counter += 1 - lazy_workflow = MyLazyWorkflow(self, ref) + lazy_workflow = self._create_loader(ref) new_callbacks.append((completion_id, lazy_workflow, is_group)) else: new_callbacks.append((completion_id, remaining_workflow, is_group)) @@ -195,8 +192,10 @@ def test_after_process_message_with_lazy_loaded_workflow(self, mock_time): # The lazy workflow object is now inside storage.callbacks[callbacks_ref] lazy_workflow_obj = storage.callbacks[callbacks_ref][0][1] - self.assertIsInstance(lazy_workflow_obj, MyLazyWorkflow) - self.assertFalse(lazy_workflow_obj.loaded) + self.assertTrue(callable(lazy_workflow_obj)) + + workflow_ref = 0 + self.assertNotIn(workflow_ref, storage.loaded_workflows) # Set up barrier barrier = AtMostOnceBarrier(self.rate_limiter_backend, "barrier_1") @@ -209,6 +208,6 @@ def test_after_process_message_with_lazy_loaded_workflow(self, mock_time): # Assertions self.assertEqual(len(storage.retrieve_calls), 1) self.assertEqual(storage.retrieve_calls[0], callbacks_ref) - self.assertTrue(lazy_workflow_obj.loaded) + self.assertIn(workflow_ref, storage.loaded_workflows) self.broker.enqueue.assert_called_once_with(self._make_message(message_timestamp=1337_000)._message, delay=None) diff --git a/dramatiq_workflow/tests/test_workflow.py b/dramatiq_workflow/tests/test_workflow.py index cd0fab8..43978d4 100644 --- a/dramatiq_workflow/tests/test_workflow.py +++ b/dramatiq_workflow/tests/test_workflow.py @@ -7,7 +7,7 @@ from .. import Chain, Group, WithDelay, Workflow, WorkflowMiddleware from .._constants import OPTION_KEY_CALLBACKS -from .._models import LazyWorkflow, SerializedCompletionCallbacks +from .._models import SerializedCompletionCallbacks from .._serialize import serialize_workflow, unserialize_workflow from .._storage import CallbackStorage @@ -28,12 +28,12 @@ def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: return self.storage[ref] -class MyLazyWorkflow(LazyWorkflow): +class MyLazyLoader: def __init__(self, workflow: dict): self._workflow = workflow self.loaded = False - def load(self) -> dict: + def __call__(self) -> dict: self.loaded = True return self._workflow @@ -345,11 +345,11 @@ def test_unserialize_lazy_workflow(self): serialized = serialize_workflow(workflow) self.assertIsNotNone(serialized) - lazy_workflow = MyLazyWorkflow(serialized) - self.assertFalse(lazy_workflow.loaded) + lazy_loader = MyLazyLoader(serialized) + self.assertFalse(lazy_loader.loaded) - unserialized = unserialize_workflow(lazy_workflow) - self.assertTrue(lazy_workflow.loaded) + unserialized = unserialize_workflow(lazy_loader) + self.assertTrue(lazy_loader.loaded) self.assertEqual(workflow, unserialized) @mock.patch("dramatiq_workflow._base.time.time") From 9a364f7264df1f36e2212a973ad8be5248ffe8a7 Mon Sep 17 00:00:00 2001 From: Nils Caspar Date: Wed, 18 Jun 2025 09:27:02 -0700 Subject: [PATCH 05/10] DedupWorkflowCallbackStorage --- dramatiq_workflow/__init__.py | 3 +- dramatiq_workflow/_storage.py | 66 +++++++++++++++ dramatiq_workflow/tests/test_storage.py | 106 +++++++++++++++++++++++- 3 files changed, 173 insertions(+), 2 deletions(-) diff --git a/dramatiq_workflow/__init__.py b/dramatiq_workflow/__init__.py index 613689a..4f10d08 100644 --- a/dramatiq_workflow/__init__.py +++ b/dramatiq_workflow/__init__.py @@ -10,11 +10,12 @@ WithDelay, WorkflowType, ) -from ._storage import CallbackStorage, InlineCallbackStorage +from ._storage import CallbackStorage, DedupWorkflowCallbackStorage, InlineCallbackStorage __all__ = [ "CallbackStorage", "Chain", + "DedupWorkflowCallbackStorage", "Group", "InlineCallbackStorage", "LazyWorkflow", diff --git a/dramatiq_workflow/_storage.py b/dramatiq_workflow/_storage.py index b12d24f..21c8708 100644 --- a/dramatiq_workflow/_storage.py +++ b/dramatiq_workflow/_storage.py @@ -1,4 +1,5 @@ import abc +from functools import partial from typing import Any from ._models import SerializedCompletionCallbacks @@ -72,3 +73,68 @@ def store(self, callbacks: SerializedCompletionCallbacks) -> SerializedCompletio def retrieve(self, ref: SerializedCompletionCallbacks) -> SerializedCompletionCallbacks: return ref + + +class DedupWorkflowCallbackStorage(CallbackStorage, abc.ABC): + """ + An abstract storage backend that separates storage of workflows from + callbacks, allowing for deduplication of workflows. + """ + + @abc.abstractmethod + def _store_workflow(self, id: str, workflow: dict) -> Any: + """ + Stores a workflow and returns a reference to it. The `id` can be used + to deduplicate workflows, and the `workflow` is the actual workflow to + store. The reference returned must be serializable by the broker's + encoder (e.g. JSON). + """ + raise NotImplementedError + + @abc.abstractmethod + def _load_workflow(self, id: str, ref: Any) -> dict: + """ + Loads a workflow using the deduplication ID and reference previously + returned by `store_workflow`. + """ + raise NotImplementedError + + def _store_callbacks(self, callbacks: list[tuple[str, Any | None, bool]]) -> Any: + """ + Stores the callbacks, which may include references to workflows. By + default, this implementation simply returns the callbacks as-is to be stored inline. + """ + return callbacks + + def _retrieve_callbacks(self, ref: Any) -> list[tuple[str, Any | None, bool]]: + """ + Retrieves callbacks from a reference. By default, this implementation + simply returns the reference as-is since the default `_store_callbacks` + implementation returns the callbacks inline. + """ + return ref + + def store(self, callbacks: SerializedCompletionCallbacks) -> Any: + """ + Stores callbacks, offloading workflow storage to `store_workflow`. + """ + new_callbacks = [] + for completion_id, remaining_workflow, is_group in callbacks: + if isinstance(remaining_workflow, dict): + remaining_workflow = self._store_workflow(completion_id, remaining_workflow) + new_callbacks.append((completion_id, remaining_workflow, is_group)) + + return self._store_callbacks(new_callbacks) + + def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: + """ + Retrieves callbacks and prepares lazy loaders for workflows. + """ + callbacks = self._retrieve_callbacks(ref) + new_callbacks = [] + for completion_id, workflow_ref, is_group in callbacks: + if workflow_ref is not None and not callable(workflow_ref): + workflow_ref = partial(self._load_workflow, completion_id, workflow_ref) + new_callbacks.append((completion_id, workflow_ref, is_group)) + + return new_callbacks diff --git a/dramatiq_workflow/tests/test_storage.py b/dramatiq_workflow/tests/test_storage.py index 3ad06c9..a837fba 100644 --- a/dramatiq_workflow/tests/test_storage.py +++ b/dramatiq_workflow/tests/test_storage.py @@ -1,7 +1,8 @@ import unittest +from typing import Any from .._models import SerializedCompletionCallbacks -from .._storage import InlineCallbackStorage +from .._storage import DedupWorkflowCallbackStorage, InlineCallbackStorage class CallbackStorageTests(unittest.TestCase): @@ -32,3 +33,106 @@ def test_determine_dedup_key_for_nested_callbacks(self): def test_determine_dedup_key_with_empty_list_raises_assertion_error(self): with self.assertRaises(AssertionError): self.storage._determine_dedup_key([]) + + +class MyDedupStorage(DedupWorkflowCallbackStorage): + def __init__(self): + self.workflows = {} + self.callbacks = {} + self.workflow_store_calls = [] + self.workflow_load_calls = [] + self.callback_store_calls = [] + self.callback_retrieve_calls = [] + + def _store_workflow(self, id: str, workflow: dict) -> Any: + self.workflow_store_calls.append((id, workflow)) + ref = f"workflow-ref-{id}" + self.workflows[ref] = workflow + return ref + + def _load_workflow(self, id: str, ref: Any) -> dict: + self.workflow_load_calls.append((id, ref)) + return self.workflows[ref] + + def _store_callbacks(self, callbacks: list[tuple[str, Any | None, bool]]) -> Any: + self.callback_store_calls.append(callbacks) + ref = f"callback-ref-{len(self.callbacks)}" + self.callbacks[ref] = callbacks + return ref + + def _retrieve_callbacks(self, ref: Any) -> list[tuple[str, Any | None, bool]]: + self.callback_retrieve_calls.append(ref) + return self.callbacks[ref] + + +class DedupWorkflowCallbackStorageTests(unittest.TestCase): + def setUp(self): + self.storage = MyDedupStorage() + + def test_store_and_retrieve(self): + workflow_dict = {"__type__": "chain", "children": []} + callbacks: SerializedCompletionCallbacks = [ + ("id1", workflow_dict, False), + ("id2", None, True), + ] + + # Store callbacks + callbacks_ref = self.storage.store(callbacks) + self.assertEqual(callbacks_ref, "callback-ref-0") + + # Check what was stored + self.assertEqual(len(self.storage.workflow_store_calls), 1) + self.assertEqual(self.storage.workflow_store_calls[0], ("id1", workflow_dict)) + + self.assertEqual(len(self.storage.callback_store_calls), 1) + stored_callbacks = self.storage.callback_store_calls[0] + self.assertEqual(len(stored_callbacks), 2) + self.assertEqual(stored_callbacks[0], ("id1", "workflow-ref-id1", False)) + self.assertEqual(stored_callbacks[1], ("id2", None, True)) + + # Retrieve callbacks + retrieved_callbacks = self.storage.retrieve(callbacks_ref) + self.assertEqual(len(self.storage.callback_retrieve_calls), 1) + self.assertEqual(self.storage.callback_retrieve_calls[0], callbacks_ref) + + self.assertEqual(len(retrieved_callbacks), 2) + + # Check first callback (with workflow) + id1, loader1, is_group1 = retrieved_callbacks[0] + self.assertEqual(id1, "id1") + self.assertFalse(is_group1) + self.assertTrue(callable(loader1)) + + # Check second callback (without workflow) + id2, loader2, is_group2 = retrieved_callbacks[1] + self.assertEqual(id2, "id2") + self.assertTrue(is_group2) + self.assertIsNone(loader2) + + # Load the workflow + self.assertEqual(len(self.storage.workflow_load_calls), 0) + loaded_workflow = loader1() + self.assertEqual(len(self.storage.workflow_load_calls), 1) + self.assertEqual(self.storage.workflow_load_calls[0], ("id1", "workflow-ref-id1")) + self.assertEqual(loaded_workflow, workflow_dict) + + def test_retrieve_does_not_wrap_callable(self): + def lazy_workflow(): + return {"__type__": "chain", "children": []} + + callbacks: SerializedCompletionCallbacks = [ + ("id1", lazy_workflow, False), + ] + + # Store should not call _store_workflow + callbacks_ref = self.storage.store(callbacks) + self.assertEqual(len(self.storage.workflow_store_calls), 0) + + stored_callbacks = self.storage.callback_store_calls[0] + self.assertIs(stored_callbacks[0][1], lazy_workflow) + + # Retrieve should not wrap the callable + retrieved_callbacks = self.storage.retrieve(callbacks_ref) + + id1, loader1, is_group1 = retrieved_callbacks[0] + self.assertIs(loader1, lazy_workflow) From 93dd33d312d66259e9e5574641e1cef8359b99ab Mon Sep 17 00:00:00 2001 From: Nils Caspar Date: Wed, 18 Jun 2025 10:02:34 -0700 Subject: [PATCH 06/10] Make sure already lazy loaded workflow can be stored --- dramatiq_workflow/_storage.py | 20 ++++++++++++++--- dramatiq_workflow/tests/test_storage.py | 29 +++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 3 deletions(-) diff --git a/dramatiq_workflow/_storage.py b/dramatiq_workflow/_storage.py index 21c8708..bb4557b 100644 --- a/dramatiq_workflow/_storage.py +++ b/dramatiq_workflow/_storage.py @@ -2,7 +2,7 @@ from functools import partial from typing import Any -from ._models import SerializedCompletionCallbacks +from ._models import LazyWorkflow, SerializedCompletionCallbacks class CallbackStorage(abc.ABC): @@ -75,6 +75,15 @@ def retrieve(self, ref: SerializedCompletionCallbacks) -> SerializedCompletionCa return ref +class _LazyLoadedWorkflow: + def __init__(self, ref: Any, load_func: LazyWorkflow): + self.ref = ref + self.load_func = load_func + + def __call__(self) -> dict: + return self.load_func() + + class DedupWorkflowCallbackStorage(CallbackStorage, abc.ABC): """ An abstract storage backend that separates storage of workflows from @@ -120,7 +129,9 @@ def store(self, callbacks: SerializedCompletionCallbacks) -> Any: """ new_callbacks = [] for completion_id, remaining_workflow, is_group in callbacks: - if isinstance(remaining_workflow, dict): + if isinstance(remaining_workflow, _LazyLoadedWorkflow): + remaining_workflow = remaining_workflow.ref + elif isinstance(remaining_workflow, dict): remaining_workflow = self._store_workflow(completion_id, remaining_workflow) new_callbacks.append((completion_id, remaining_workflow, is_group)) @@ -134,7 +145,10 @@ def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: new_callbacks = [] for completion_id, workflow_ref, is_group in callbacks: if workflow_ref is not None and not callable(workflow_ref): - workflow_ref = partial(self._load_workflow, completion_id, workflow_ref) + workflow_ref = _LazyLoadedWorkflow( + ref=workflow_ref, + load_func=partial(self._load_workflow, completion_id, workflow_ref), + ) new_callbacks.append((completion_id, workflow_ref, is_group)) return new_callbacks diff --git a/dramatiq_workflow/tests/test_storage.py b/dramatiq_workflow/tests/test_storage.py index a837fba..23b3532 100644 --- a/dramatiq_workflow/tests/test_storage.py +++ b/dramatiq_workflow/tests/test_storage.py @@ -136,3 +136,32 @@ def lazy_workflow(): id1, loader1, is_group1 = retrieved_callbacks[0] self.assertIs(loader1, lazy_workflow) + + def test_store_with_already_lazy_loaded_workflow(self): + # This test ensures that when we store a workflow that has already been + # loaded and wrapped by the storage, we don't try to store it again, + # but instead use its reference. + + # 1. Store a workflow and retrieve it. + workflow_dict = {"__type__": "chain", "children": []} + callbacks1: SerializedCompletionCallbacks = [("id1", workflow_dict, False)] + callbacks_ref1 = self.storage.store(callbacks1) + retrieved_callbacks1 = self.storage.retrieve(callbacks_ref1) + _, lazy_loader, _ = retrieved_callbacks1[0] + + # We should have one call to store the workflow + self.assertEqual(len(self.storage.workflow_store_calls), 1) + self.assertTrue(callable(lazy_loader)) + + # 2. Now, create new callbacks using the lazy loader from the previous step + # and store them. + callbacks2: SerializedCompletionCallbacks = [("id2", lazy_loader, False)] + callbacks_ref2 = self.storage.store(callbacks2) + + # 3. _store_workflow should NOT have been called again. + self.assertEqual(len(self.storage.workflow_store_calls), 1) + + # 4. The new stored callbacks should contain the original workflow reference. + stored_callbacks2 = self.storage.callbacks[callbacks_ref2] + self.assertEqual(len(stored_callbacks2), 1) + self.assertEqual(stored_callbacks2[0], ("id2", "workflow-ref-id1", False)) From c95cf14c50a6676c5f51e26547cbe4c93104db42 Mon Sep 17 00:00:00 2001 From: Nils Caspar Date: Wed, 18 Jun 2025 10:21:10 -0700 Subject: [PATCH 07/10] Remove _determine_dedup_key and update README --- README.md | 47 +++++++++------ dramatiq_workflow/_storage.py | 22 ------- dramatiq_workflow/tests/test_middleware.py | 39 ------------ dramatiq_workflow/tests/test_storage.py | 32 +--------- dramatiq_workflow/tests/test_workflow.py | 70 ---------------------- 5 files changed, 31 insertions(+), 179 deletions(-) diff --git a/README.md b/README.md index 0ba0aab..62ff7d1 100644 --- a/README.md +++ b/README.md @@ -237,29 +237,42 @@ storage = MyS3Storage() # Your custom storage backend broker.add_middleware(WorkflowMiddleware(backend, callback_storage=storage)) ``` -#### Deduplicating Callbacks +##### Deduplicating Workflows -For convenience, `CallbackStorage` provides a helper method `_determine_dedup_key` -that you can use to deduplicate callbacks for `Group` tasks. If the -deduplication key already exists, storing can be skipped. +For convenience, `dramatiq-workflow` provides an abstract +`DedupWorkflowCallbackStorage` class that you can use to separate the storage +of workflows from the storage of callbacks. This is useful for deduplicating +large workflow definitions that may be part of multiple callbacks, especially +when chaining large groups of tasks. + +To use it, you need to subclass `DedupWorkflowCallbackStorage` and implement +the `_store_workflow` and `_load_workflow` methods. ```python from typing import Any -from dramatiq_workflow import CallbackStorage, SerializedCompletionCallbacks +from dramatiq_workflow import DedupWorkflowCallbackStorage -class MyDedupStorage(CallbackStorage): +class MyDedupStorage(DedupWorkflowCallbackStorage): def __init__(self): - self.__storage = MagicStorage() - - def store(self, callbacks: SerializedCompletionCallbacks) -> str: - dedup_key, _ = self._determine_dedup_key(callbacks) - if not self.__storage.exists(dedup_key): - self.__storage.set(dedup_key, callbacks) - return dedup_key - - def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: - # ref will be the deduplication key from before - return self.__storage.get(ref) + # In a real application, this would be a persistent storage like a + # database or a distributed cache so that workers and producers can + # both access it. + self.__workflow_storage = {} + + def _store_workflow(self, id: str, workflow: dict) -> Any: + # Using the `id` (which is the completion ID) to deduplicate. + workflow_key = id + if workflow_key not in self.__workflow_storage: + self.__workflow_storage[workflow_key] = workflow + return workflow_key # Return a reference to the workflow. + + def _load_workflow(self, id: str, ref: Any) -> dict: + # `ref` is what `_store_workflow` returned. + return self.__workflow_storage[ref] + +# You can also override `_store_callbacks` and `_retrieve_callbacks` to store +# callbacks separately, for example in a database or S3. You can deduplicate +# these as well by using the ID of the last callback, i.e. callbacks[-1][0]. ``` ### Barrier diff --git a/dramatiq_workflow/_storage.py b/dramatiq_workflow/_storage.py index bb4557b..ad9cdfa 100644 --- a/dramatiq_workflow/_storage.py +++ b/dramatiq_workflow/_storage.py @@ -39,28 +39,6 @@ def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: """ raise NotImplementedError - def _determine_dedup_key(self, callbacks: SerializedCompletionCallbacks) -> tuple[str, bool]: - """ - Determines a deduplication key for the given callbacks. - - This is used by deduplication storage backends to identify unique - callback sets. - - Returns: - A tuple containing the completion ID and a boolean indicating if - the callbacks are part of a group (i.e. if deduplication is - strictly needed). - """ - - # NOTE: `Workflow.__augment_message` only calls the `CallbackStorage` - # when the `callbacks` list is not empty. This `assert` should always - # hold true. - assert isinstance(callbacks, list) and len(callbacks) > 0, "Callbacks must be a non-empty list" - - last_callback = callbacks[-1] - completion_id, _, is_group = last_callback - return completion_id, is_group - class InlineCallbackStorage(CallbackStorage): """ diff --git a/dramatiq_workflow/tests/test_middleware.py b/dramatiq_workflow/tests/test_middleware.py index 0484705..6802757 100644 --- a/dramatiq_workflow/tests/test_middleware.py +++ b/dramatiq_workflow/tests/test_middleware.py @@ -14,22 +14,6 @@ from dramatiq_workflow._storage import CallbackStorage -class MyDedupStorage(CallbackStorage): - def __init__(self): - self.storage = {} - self.retrieve_calls = [] - - def store(self, callbacks: SerializedCompletionCallbacks) -> Any: - dedup_key, _ = self._determine_dedup_key(callbacks) - if dedup_key not in self.storage: - self.storage[dedup_key] = callbacks - return dedup_key - - def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: - self.retrieve_calls.append(ref) - return self.storage[ref] - - class MyLazyStorage(CallbackStorage): def __init__(self): self.workflows = {} @@ -154,29 +138,6 @@ def test_after_process_message_with_barriered_workflow(self, mock_time): self.middleware.after_process_message(self.broker, message) self.broker.enqueue.assert_called_once_with(self._make_message(message_timestamp=1337_000)._message, delay=None) - @mock.patch("dramatiq_workflow._base.time.time") - def test_after_process_message_with_custom_storage(self, mock_time): - mock_time.return_value = 1337 - storage = MyDedupStorage() - self.middleware = WorkflowMiddleware(self.rate_limiter_backend, callback_storage=storage) - - serialized_workflow = self._create_serialized_workflow() - callbacks = [("barrier_1", serialized_workflow, True)] - - barrier = AtMostOnceBarrier(self.rate_limiter_backend, "barrier_1") - barrier.create(1) - - dedup_key = storage.store(callbacks) - - message = self._make_message({OPTION_KEY_CALLBACKS: dedup_key}) - - self.middleware.after_process_message(self.broker, message) - - self.assertEqual(len(storage.retrieve_calls), 1) - self.assertEqual(storage.retrieve_calls[0], dedup_key) - - self.broker.enqueue.assert_called_once_with(self._make_message(message_timestamp=1337_000)._message, delay=None) - @mock.patch("dramatiq_workflow._base.time.time") def test_after_process_message_with_lazy_loaded_workflow(self, mock_time): mock_time.return_value = 1337 diff --git a/dramatiq_workflow/tests/test_storage.py b/dramatiq_workflow/tests/test_storage.py index 23b3532..8467114 100644 --- a/dramatiq_workflow/tests/test_storage.py +++ b/dramatiq_workflow/tests/test_storage.py @@ -2,37 +2,7 @@ from typing import Any from .._models import SerializedCompletionCallbacks -from .._storage import DedupWorkflowCallbackStorage, InlineCallbackStorage - - -class CallbackStorageTests(unittest.TestCase): - def setUp(self): - self.storage = InlineCallbackStorage() - - def test_determine_dedup_key_for_group(self): - callbacks: SerializedCompletionCallbacks = [("group-id-123", None, True)] - dedup_key, is_group = self.storage._determine_dedup_key(callbacks) - self.assertEqual(dedup_key, "group-id-123") - self.assertTrue(is_group) - - def test_determine_dedup_key_for_chain(self): - callbacks: SerializedCompletionCallbacks = [("chain-id-456", {"__type__": "chain"}, False)] - dedup_key, is_group = self.storage._determine_dedup_key(callbacks) - self.assertEqual(dedup_key, "chain-id-456") - self.assertFalse(is_group) - - def test_determine_dedup_key_for_nested_callbacks(self): - callbacks: SerializedCompletionCallbacks = [ - ("chain-id-456", {"__type__": "chain"}, False), - ("group-id-123", None, True), - ] - dedup_key, is_group = self.storage._determine_dedup_key(callbacks) - self.assertEqual(dedup_key, "group-id-123") - self.assertTrue(is_group) - - def test_determine_dedup_key_with_empty_list_raises_assertion_error(self): - with self.assertRaises(AssertionError): - self.storage._determine_dedup_key([]) +from .._storage import DedupWorkflowCallbackStorage class MyDedupStorage(DedupWorkflowCallbackStorage): diff --git a/dramatiq_workflow/tests/test_workflow.py b/dramatiq_workflow/tests/test_workflow.py index 43978d4..11aebdd 100644 --- a/dramatiq_workflow/tests/test_workflow.py +++ b/dramatiq_workflow/tests/test_workflow.py @@ -1,31 +1,11 @@ import unittest -from typing import Any from unittest import mock import dramatiq import dramatiq.rate_limits from .. import Chain, Group, WithDelay, Workflow, WorkflowMiddleware -from .._constants import OPTION_KEY_CALLBACKS -from .._models import SerializedCompletionCallbacks from .._serialize import serialize_workflow, unserialize_workflow -from .._storage import CallbackStorage - - -class MyDedupStorage(CallbackStorage): - def __init__(self): - self.storage = {} - self.store_calls = [] - - def store(self, callbacks: SerializedCompletionCallbacks) -> Any: - dedup_key, is_group = self._determine_dedup_key(callbacks) - self.store_calls.append((dedup_key, callbacks, is_group)) - if dedup_key not in self.storage: - self.storage[dedup_key] = callbacks - return dedup_key - - def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: - return self.storage[ref] class MyLazyLoader: @@ -431,53 +411,3 @@ def test_nested_delays(self, time_mock): ), delay=20, ) - - @mock.patch("dramatiq_workflow._base.time.time") - def test_workflow_with_custom_storage(self, time_mock): - time_mock.return_value = 1717526000.12 - - storage = MyDedupStorage() - # The broker is a mock object, we can just replace the middleware list - self.broker.middleware = [ - WorkflowMiddleware( - rate_limiter_backend=self.rate_limiter_backend, - barrier_type=self.barrier, - callback_storage=storage, - ) - ] - - workflow = Workflow(Group(self.task.message(), self.task.message()), broker=self.broker) - workflow.run() - - # Assertions - self.assertEqual(len(storage.store_calls), 2) - - dedup_key1, callbacks1, is_group1 = storage.store_calls[0] - dedup_key2, callbacks2, is_group2 = storage.store_calls[1] - self.assertEqual(dedup_key1, dedup_key2) - self.assertEqual(callbacks1, callbacks2) - self.assertTrue(is_group1) - self.assertTrue(is_group2) - - self.assertEqual(len(storage.storage), 1) - self.assertIn(dedup_key1, storage.storage) - - self.broker.enqueue.assert_has_calls( - [ - mock.call(mock.ANY, delay=None), - mock.call(mock.ANY, delay=None), - ], - any_order=True, - ) - - # Check the options passed to enqueue - self.assertEqual(len(self.broker.enqueue.call_args_list), 2) - message1 = self.broker.enqueue.call_args_list[0][0][0] - delay1 = self.broker.enqueue.call_args_list[0][1]["delay"] - message2 = self.broker.enqueue.call_args_list[1][0][0] - delay2 = self.broker.enqueue.call_args_list[1][1]["delay"] - - self.assertEqual(message1.options[OPTION_KEY_CALLBACKS], dedup_key1) - self.assertEqual(delay1, None) - self.assertEqual(message2.options[OPTION_KEY_CALLBACKS], dedup_key2) - self.assertEqual(delay2, None) From 08acc39e222f25f1ee3fc70d6e04d6e0294c5585 Mon Sep 17 00:00:00 2001 From: Nils Caspar Date: Wed, 18 Jun 2025 15:33:37 -0700 Subject: [PATCH 08/10] Add __str__ --- dramatiq_workflow/_storage.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dramatiq_workflow/_storage.py b/dramatiq_workflow/_storage.py index ad9cdfa..b9a8d23 100644 --- a/dramatiq_workflow/_storage.py +++ b/dramatiq_workflow/_storage.py @@ -61,6 +61,9 @@ def __init__(self, ref: Any, load_func: LazyWorkflow): def __call__(self) -> dict: return self.load_func() + def __str__(self): + return f"_LazyLoadedWorkflow({self.ref})" + class DedupWorkflowCallbackStorage(CallbackStorage, abc.ABC): """ From 5e37970b6857082491d326219cb4c3bb8129599d Mon Sep 17 00:00:00 2001 From: Nils Caspar Date: Mon, 23 Jun 2025 11:11:44 -0700 Subject: [PATCH 09/10] Fix tests --- dramatiq_workflow/tests/test_workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dramatiq_workflow/tests/test_workflow.py b/dramatiq_workflow/tests/test_workflow.py index 44dfd14..8b20ee3 100644 --- a/dramatiq_workflow/tests/test_workflow.py +++ b/dramatiq_workflow/tests/test_workflow.py @@ -462,7 +462,7 @@ def test_nested_delays(self, time_mock): ) def test_middleware_is_cached(self): - workflow = Workflow(Chain(self.task.message(), self.task.message()), broker=self.broker) + workflow = Workflow(Chain(), broker=self.broker) # Access middleware property multiple times workflow.run() From 64cc2bc1642c4907b45948bccc4fc69d70f35955 Mon Sep 17 00:00:00 2001 From: Nils Caspar Date: Mon, 23 Jun 2025 11:53:22 -0700 Subject: [PATCH 10/10] Simplify and add typing for references --- README.md | 4 -- dramatiq_workflow/_storage.py | 64 ++++++++++++------------- dramatiq_workflow/tests/test_storage.py | 48 ++++--------------- 3 files changed, 38 insertions(+), 78 deletions(-) diff --git a/README.md b/README.md index 62ff7d1..8a97d0c 100644 --- a/README.md +++ b/README.md @@ -269,10 +269,6 @@ class MyDedupStorage(DedupWorkflowCallbackStorage): def _load_workflow(self, id: str, ref: Any) -> dict: # `ref` is what `_store_workflow` returned. return self.__workflow_storage[ref] - -# You can also override `_store_callbacks` and `_retrieve_callbacks` to store -# callbacks separately, for example in a database or S3. You can deduplicate -# these as well by using the ID of the last callback, i.e. callbacks[-1][0]. ``` ### Barrier diff --git a/dramatiq_workflow/_storage.py b/dramatiq_workflow/_storage.py index b9a8d23..11d2fbe 100644 --- a/dramatiq_workflow/_storage.py +++ b/dramatiq_workflow/_storage.py @@ -1,17 +1,19 @@ import abc from functools import partial -from typing import Any +from typing import Any, Generic, TypeVar, cast from ._models import LazyWorkflow, SerializedCompletionCallbacks +CallbacksRefT = TypeVar("CallbacksRefT") -class CallbackStorage(abc.ABC): + +class CallbackStorage(abc.ABC, Generic[CallbacksRefT]): """ Abstract base class for callback storage backends. """ @abc.abstractmethod - def store(self, callbacks: SerializedCompletionCallbacks) -> Any: + def store(self, callbacks: SerializedCompletionCallbacks) -> CallbacksRefT: """ Stores callbacks and returns a reference to them. @@ -27,7 +29,7 @@ def store(self, callbacks: SerializedCompletionCallbacks) -> Any: raise NotImplementedError @abc.abstractmethod - def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: + def retrieve(self, ref: CallbacksRefT) -> SerializedCompletionCallbacks: """ Retrieves callbacks using a reference. @@ -40,7 +42,7 @@ def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: raise NotImplementedError -class InlineCallbackStorage(CallbackStorage): +class InlineCallbackStorage(CallbackStorage[SerializedCompletionCallbacks]): """ A storage backend that stores callbacks inline with the message. This is the default storage backend. @@ -53,7 +55,11 @@ def retrieve(self, ref: SerializedCompletionCallbacks) -> SerializedCompletionCa return ref -class _LazyLoadedWorkflow: +WorkflowRefT = TypeVar("WorkflowRefT") +CompletionCallbacksWithWorkflowRef = list[tuple[str, WorkflowRefT | None, bool]] + + +class _LazyLoadedWorkflow(Generic[WorkflowRefT]): def __init__(self, ref: Any, load_func: LazyWorkflow): self.ref = ref self.load_func = load_func @@ -65,14 +71,14 @@ def __str__(self): return f"_LazyLoadedWorkflow({self.ref})" -class DedupWorkflowCallbackStorage(CallbackStorage, abc.ABC): +class DedupWorkflowCallbackStorage(CallbackStorage[CompletionCallbacksWithWorkflowRef], abc.ABC, Generic[WorkflowRefT]): """ An abstract storage backend that separates storage of workflows from callbacks, allowing for deduplication of workflows. """ @abc.abstractmethod - def _store_workflow(self, id: str, workflow: dict) -> Any: + def _store_workflow(self, id: str, workflow: dict) -> WorkflowRefT: """ Stores a workflow and returns a reference to it. The `id` can be used to deduplicate workflows, and the `workflow` is the actual workflow to @@ -82,49 +88,39 @@ def _store_workflow(self, id: str, workflow: dict) -> Any: raise NotImplementedError @abc.abstractmethod - def _load_workflow(self, id: str, ref: Any) -> dict: + def _load_workflow(self, id: str, ref: WorkflowRefT) -> dict: """ Loads a workflow using the deduplication ID and reference previously returned by `store_workflow`. """ raise NotImplementedError - def _store_callbacks(self, callbacks: list[tuple[str, Any | None, bool]]) -> Any: - """ - Stores the callbacks, which may include references to workflows. By - default, this implementation simply returns the callbacks as-is to be stored inline. - """ - return callbacks - - def _retrieve_callbacks(self, ref: Any) -> list[tuple[str, Any | None, bool]]: - """ - Retrieves callbacks from a reference. By default, this implementation - simply returns the reference as-is since the default `_store_callbacks` - implementation returns the callbacks inline. - """ - return ref - - def store(self, callbacks: SerializedCompletionCallbacks) -> Any: + def store(self, callbacks: SerializedCompletionCallbacks) -> CompletionCallbacksWithWorkflowRef[WorkflowRefT]: """ Stores callbacks, offloading workflow storage to `store_workflow`. """ - new_callbacks = [] + new_callbacks: CompletionCallbacksWithWorkflowRef[WorkflowRefT] = [] for completion_id, remaining_workflow, is_group in callbacks: + remaining_workflow_ref = None if isinstance(remaining_workflow, _LazyLoadedWorkflow): - remaining_workflow = remaining_workflow.ref + remaining_workflow_ref = cast(WorkflowRefT, remaining_workflow.ref) elif isinstance(remaining_workflow, dict): - remaining_workflow = self._store_workflow(completion_id, remaining_workflow) - new_callbacks.append((completion_id, remaining_workflow, is_group)) + remaining_workflow_ref = self._store_workflow(completion_id, remaining_workflow) + elif remaining_workflow is not None: + raise TypeError( + "Unsupported workflow type: " + f"{type(remaining_workflow)}. Expected None, dict, or _LazyLoadedWorkflow." + ) + new_callbacks.append((completion_id, remaining_workflow_ref, is_group)) - return self._store_callbacks(new_callbacks) + return new_callbacks - def retrieve(self, ref: Any) -> SerializedCompletionCallbacks: + def retrieve(self, ref: CompletionCallbacksWithWorkflowRef[WorkflowRefT]) -> SerializedCompletionCallbacks: """ Retrieves callbacks and prepares lazy loaders for workflows. """ - callbacks = self._retrieve_callbacks(ref) - new_callbacks = [] - for completion_id, workflow_ref, is_group in callbacks: + new_callbacks: SerializedCompletionCallbacks = [] + for completion_id, workflow_ref, is_group in ref: if workflow_ref is not None and not callable(workflow_ref): workflow_ref = _LazyLoadedWorkflow( ref=workflow_ref, diff --git a/dramatiq_workflow/tests/test_storage.py b/dramatiq_workflow/tests/test_storage.py index 8467114..13a2284 100644 --- a/dramatiq_workflow/tests/test_storage.py +++ b/dramatiq_workflow/tests/test_storage.py @@ -11,8 +11,6 @@ def __init__(self): self.callbacks = {} self.workflow_store_calls = [] self.workflow_load_calls = [] - self.callback_store_calls = [] - self.callback_retrieve_calls = [] def _store_workflow(self, id: str, workflow: dict) -> Any: self.workflow_store_calls.append((id, workflow)) @@ -24,16 +22,6 @@ def _load_workflow(self, id: str, ref: Any) -> dict: self.workflow_load_calls.append((id, ref)) return self.workflows[ref] - def _store_callbacks(self, callbacks: list[tuple[str, Any | None, bool]]) -> Any: - self.callback_store_calls.append(callbacks) - ref = f"callback-ref-{len(self.callbacks)}" - self.callbacks[ref] = callbacks - return ref - - def _retrieve_callbacks(self, ref: Any) -> list[tuple[str, Any | None, bool]]: - self.callback_retrieve_calls.append(ref) - return self.callbacks[ref] - class DedupWorkflowCallbackStorageTests(unittest.TestCase): def setUp(self): @@ -48,23 +36,14 @@ def test_store_and_retrieve(self): # Store callbacks callbacks_ref = self.storage.store(callbacks) - self.assertEqual(callbacks_ref, "callback-ref-0") + self.assertEqual(callbacks_ref, [("id1", "workflow-ref-id1", False), ("id2", None, True)]) # Check what was stored self.assertEqual(len(self.storage.workflow_store_calls), 1) self.assertEqual(self.storage.workflow_store_calls[0], ("id1", workflow_dict)) - self.assertEqual(len(self.storage.callback_store_calls), 1) - stored_callbacks = self.storage.callback_store_calls[0] - self.assertEqual(len(stored_callbacks), 2) - self.assertEqual(stored_callbacks[0], ("id1", "workflow-ref-id1", False)) - self.assertEqual(stored_callbacks[1], ("id2", None, True)) - # Retrieve callbacks retrieved_callbacks = self.storage.retrieve(callbacks_ref) - self.assertEqual(len(self.storage.callback_retrieve_calls), 1) - self.assertEqual(self.storage.callback_retrieve_calls[0], callbacks_ref) - self.assertEqual(len(retrieved_callbacks), 2) # Check first callback (with workflow) @@ -81,31 +60,22 @@ def test_store_and_retrieve(self): # Load the workflow self.assertEqual(len(self.storage.workflow_load_calls), 0) + assert callable(loader1) loaded_workflow = loader1() self.assertEqual(len(self.storage.workflow_load_calls), 1) self.assertEqual(self.storage.workflow_load_calls[0], ("id1", "workflow-ref-id1")) self.assertEqual(loaded_workflow, workflow_dict) - def test_retrieve_does_not_wrap_callable(self): - def lazy_workflow(): + def test_store_with_unsupported_workflow_type(self): + def unsupported_lazy_workflow(): return {"__type__": "chain", "children": []} callbacks: SerializedCompletionCallbacks = [ - ("id1", lazy_workflow, False), + ("id1", unsupported_lazy_workflow, False), ] - # Store should not call _store_workflow - callbacks_ref = self.storage.store(callbacks) - self.assertEqual(len(self.storage.workflow_store_calls), 0) - - stored_callbacks = self.storage.callback_store_calls[0] - self.assertIs(stored_callbacks[0][1], lazy_workflow) - - # Retrieve should not wrap the callable - retrieved_callbacks = self.storage.retrieve(callbacks_ref) - - id1, loader1, is_group1 = retrieved_callbacks[0] - self.assertIs(loader1, lazy_workflow) + with self.assertRaises(TypeError): + self.storage.store(callbacks) def test_store_with_already_lazy_loaded_workflow(self): # This test ensures that when we store a workflow that has already been @@ -132,6 +102,4 @@ def test_store_with_already_lazy_loaded_workflow(self): self.assertEqual(len(self.storage.workflow_store_calls), 1) # 4. The new stored callbacks should contain the original workflow reference. - stored_callbacks2 = self.storage.callbacks[callbacks_ref2] - self.assertEqual(len(stored_callbacks2), 1) - self.assertEqual(stored_callbacks2[0], ("id2", "workflow-ref-id1", False)) + self.assertEqual(callbacks_ref2[0][1], callbacks_ref1[0][1])