diff --git a/dramatiq_workflow/_base.py b/dramatiq_workflow/_base.py index 3ac58b2..2ad018a 100644 --- a/dramatiq_workflow/_base.py +++ b/dramatiq_workflow/_base.py @@ -145,6 +145,21 @@ def __workflow_with_completion_callbacks(self, task, completion_callbacks) -> "W ) def __schedule_noop(self, completion_callbacks: SerializedCompletionCallbacks): + """ + Schedules a no-op task to trigger the workflow middleware. + + This is necessary when a Chain or a Group is empty, to ensure that + the completion callbacks are still processed and the workflow can + continue. + """ + + if not self._delay: + # If there is no delay, we can process the completion callbacks + # immediately instead of scheduling a noop task. This saves us a + # round trip to the broker and having to encode the workflow. + self.__middleware._process_completion_callbacks(self.broker, completion_callbacks) + return + noop_message = workflow_noop.message() noop_message = self.__augment_message(noop_message, completion_callbacks) self.broker.enqueue(noop_message, delay=self._delay) diff --git a/dramatiq_workflow/_middleware.py b/dramatiq_workflow/_middleware.py index f751610..1f2eb5d 100644 --- a/dramatiq_workflow/_middleware.py +++ b/dramatiq_workflow/_middleware.py @@ -38,6 +38,11 @@ def after_process_message( if completion_callbacks is None: return + self._process_completion_callbacks(broker, completion_callbacks) + + def _process_completion_callbacks( + self, broker: dramatiq.Broker, completion_callbacks: SerializedCompletionCallbacks + ): # Go through the completion callbacks backwards until we hit the first non-completed barrier while len(completion_callbacks) > 0: completion_id, remaining_workflow, propagate = completion_callbacks[-1] diff --git a/dramatiq_workflow/tests/test_workflow.py b/dramatiq_workflow/tests/test_workflow.py index 8c023a6..e82974b 100644 --- a/dramatiq_workflow/tests/test_workflow.py +++ b/dramatiq_workflow/tests/test_workflow.py @@ -189,13 +189,58 @@ def test_simple_workflow(self, time_mock): delay=None, ) + def test_empty_chain_workflow(self): + middleware = self.broker.middleware[0] + middleware._process_completion_callbacks = mock.MagicMock() + + workflow = Workflow(Chain(), broker=self.broker) + workflow.run() + + self.broker.enqueue.assert_not_called() + middleware._process_completion_callbacks.assert_called_once_with(self.broker, []) + + def test_empty_group_workflow(self): + middleware = self.broker.middleware[0] + middleware._process_completion_callbacks = mock.MagicMock() + + workflow = Workflow(Group(), broker=self.broker) + workflow.run() + + self.broker.enqueue.assert_not_called() + middleware._process_completion_callbacks.assert_called_once_with(self.broker, []) + + @mock.patch("dramatiq_workflow._base.workflow_noop.message") @mock.patch("dramatiq_workflow._base.time.time") - def test_noop_workflow(self, time_mock): + def test_empty_chain_workflow_with_delay(self, time_mock, noop_message_mock): time_mock.return_value = 1717526000.12 - workflow = Workflow(Chain(), broker=self.broker) + updated_timestamp = time_mock.return_value * 1000 + + original_noop_message = self.__make_message(999) + noop_message_mock.return_value = original_noop_message + + middleware = self.broker.middleware[0] + middleware._process_completion_callbacks = mock.MagicMock() + + workflow = Workflow(WithDelay(Chain(), delay=10), broker=self.broker) workflow.run() + middleware._process_completion_callbacks.assert_not_called() + noop_message_mock.assert_called_once_with() + self.broker.enqueue.assert_called_once() + args, kwargs = self.broker.enqueue.call_args + enqueued_message = args[0] + + self.assertEqual(kwargs, {"delay": 10}) + self.assertEqual(enqueued_message.message_id, original_noop_message.message_id) + self.assertEqual(enqueued_message.message_timestamp, updated_timestamp) + self.assertEqual(enqueued_message.options, {}) + + def test_missing_middleware(self): + self.broker = mock.MagicMock(middleware=[]) + workflow = Workflow(Chain(), broker=self.broker) + with self.assertRaisesRegex(RuntimeError, "WorkflowMiddleware middleware not found"): + workflow.run() def test_unsupported_workflow(self): with self.assertRaises(TypeError):