Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions dramatiq_workflow/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,21 @@ def __workflow_with_completion_callbacks(self, task, completion_callbacks) -> "W
)

def __schedule_noop(self, completion_callbacks: SerializedCompletionCallbacks):
"""
Schedules a no-op task to trigger the workflow middleware.

This is necessary when a Chain or a Group is empty, to ensure that
the completion callbacks are still processed and the workflow can
continue.
"""

if not self._delay:
# If there is no delay, we can process the completion callbacks
# immediately instead of scheduling a noop task. This saves us a
# round trip to the broker and having to encode the workflow.
self.__middleware._process_completion_callbacks(self.broker, completion_callbacks)
return

noop_message = workflow_noop.message()
noop_message = self.__augment_message(noop_message, completion_callbacks)
self.broker.enqueue(noop_message, delay=self._delay)
Expand Down
5 changes: 5 additions & 0 deletions dramatiq_workflow/_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def after_process_message(
if completion_callbacks is None:
return

self._process_completion_callbacks(broker, completion_callbacks)

def _process_completion_callbacks(
self, broker: dramatiq.Broker, completion_callbacks: SerializedCompletionCallbacks
):
# Go through the completion callbacks backwards until we hit the first non-completed barrier
while len(completion_callbacks) > 0:
completion_id, remaining_workflow, propagate = completion_callbacks[-1]
Expand Down
49 changes: 47 additions & 2 deletions dramatiq_workflow/tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,58 @@ def test_simple_workflow(self, time_mock):
delay=None,
)

def test_empty_chain_workflow(self):
middleware = self.broker.middleware[0]
middleware._process_completion_callbacks = mock.MagicMock()

workflow = Workflow(Chain(), broker=self.broker)
workflow.run()

self.broker.enqueue.assert_not_called()
middleware._process_completion_callbacks.assert_called_once_with(self.broker, [])

def test_empty_group_workflow(self):
middleware = self.broker.middleware[0]
middleware._process_completion_callbacks = mock.MagicMock()

workflow = Workflow(Group(), broker=self.broker)
workflow.run()

self.broker.enqueue.assert_not_called()
middleware._process_completion_callbacks.assert_called_once_with(self.broker, [])

@mock.patch("dramatiq_workflow._base.workflow_noop.message")
@mock.patch("dramatiq_workflow._base.time.time")
def test_noop_workflow(self, time_mock):
def test_empty_chain_workflow_with_delay(self, time_mock, noop_message_mock):
time_mock.return_value = 1717526000.12
workflow = Workflow(Chain(), broker=self.broker)
updated_timestamp = time_mock.return_value * 1000

original_noop_message = self.__make_message(999)
noop_message_mock.return_value = original_noop_message

middleware = self.broker.middleware[0]
middleware._process_completion_callbacks = mock.MagicMock()

workflow = Workflow(WithDelay(Chain(), delay=10), broker=self.broker)
workflow.run()

middleware._process_completion_callbacks.assert_not_called()
noop_message_mock.assert_called_once_with()

self.broker.enqueue.assert_called_once()
args, kwargs = self.broker.enqueue.call_args
enqueued_message = args[0]

self.assertEqual(kwargs, {"delay": 10})
self.assertEqual(enqueued_message.message_id, original_noop_message.message_id)
self.assertEqual(enqueued_message.message_timestamp, updated_timestamp)
self.assertEqual(enqueued_message.options, {})

def test_missing_middleware(self):
self.broker = mock.MagicMock(middleware=[])
workflow = Workflow(Chain(), broker=self.broker)
with self.assertRaisesRegex(RuntimeError, "WorkflowMiddleware middleware not found"):
workflow.run()

def test_unsupported_workflow(self):
with self.assertRaises(TypeError):
Expand Down