From 37198cd64a3ecb5933c9864b73cf5a4e506e5cc1 Mon Sep 17 00:00:00 2001 From: Nils Caspar Date: Tue, 17 Jun 2025 12:11:28 -0700 Subject: [PATCH] Cache middleware properly --- dramatiq_workflow/_base.py | 6 +++--- dramatiq_workflow/tests/test_workflow.py | 26 ++++++++++++++++++------ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/dramatiq_workflow/_base.py b/dramatiq_workflow/_base.py index c5d7e2c..3ac58b2 100644 --- a/dramatiq_workflow/_base.py +++ b/dramatiq_workflow/_base.py @@ -164,10 +164,10 @@ def __augment_message(self, message: Message, completion_callbacks: SerializedCo @property def __middleware(self) -> WorkflowMiddleware: - if not hasattr(self, "__cached_middleware"): + if not hasattr(self, "_cached_middleware"): for middleware in self.broker.middleware: if isinstance(middleware, WorkflowMiddleware): - self.__cached_middleware = middleware + self._cached_middleware = middleware break else: raise RuntimeError( @@ -175,7 +175,7 @@ def __middleware(self) -> WorkflowMiddleware: "to set it up? It is required if you want to use " "workflows." ) - return self.__cached_middleware + return self._cached_middleware @property def __rate_limiter_backend(self) -> dramatiq.rate_limits.RateLimiterBackend: diff --git a/dramatiq_workflow/tests/test_workflow.py b/dramatiq_workflow/tests/test_workflow.py index 3f8f71d..8c023a6 100644 --- a/dramatiq_workflow/tests/test_workflow.py +++ b/dramatiq_workflow/tests/test_workflow.py @@ -12,14 +12,18 @@ class WorkflowTests(unittest.TestCase): def setUp(self): self.rate_limiter_backend = mock.create_autospec(dramatiq.rate_limits.RateLimiterBackend, instance=True) self.barrier = mock.create_autospec(dramatiq.rate_limits.Barrier) - self.broker = mock.MagicMock( - middleware=[ - WorkflowMiddleware( - rate_limiter_backend=self.rate_limiter_backend, - barrier_type=self.barrier, - ) + self.broker = mock.MagicMock() + self.workflow_middleware = WorkflowMiddleware( + rate_limiter_backend=self.rate_limiter_backend, + barrier_type=self.barrier, + ) + self.middleware_list = mock.PropertyMock( + return_value=[ + self.workflow_middleware, ] ) + type(self.broker).middleware = self.middleware_list + self.task = mock.MagicMock() self.task.message.side_effect = lambda *args, **kwargs: self.__make_message( self.__generate_id(), *args, **kwargs @@ -389,3 +393,13 @@ def test_nested_delays(self, time_mock): ), delay=20, ) + + def test_middleware_is_cached(self): + workflow = Workflow(Chain(self.task.message(), self.task.message()), broker=self.broker) + + # Access middleware property multiple times + workflow.run() + workflow.run() + + # Check that broker.middleware was accessed only once + self.middleware_list.assert_called_once()