-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Expand file tree
/
Copy pathcheckpoint_with_resume.py
More file actions
157 lines (126 loc) · 5.7 KB
/
checkpoint_with_resume.py
File metadata and controls
157 lines (126 loc) · 5.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Copyright (c) Microsoft. All rights reserved.
"""
Sample: Checkpointing and Resuming a Workflow
Purpose:
This sample shows how to enable checkpointing for a long-running workflow
that can be paused and resumed.
What you learn:
- How to configure checkpointing storage (InMemoryCheckpointStorage for testing)
- How to resume a workflow from a checkpoint after interruption
- How to implement executor state management with checkpoint hooks
- How to handle workflow interruptions and automatic recovery
Pipeline:
This sample shows a workflow that computes factor pairs for numbers up to a given limit:
1) A start executor that receives the upper limit and creates the initial task
2) A worker executor that processes each number to find its factor pairs
3) The worker uses checkpoint hooks to save/restore its internal state
Prerequisites:
- Basic understanding of workflow concepts, including executors, edges, events, etc.
"""
import asyncio
import sys
from dataclasses import dataclass
from random import random
from typing import Any
from agent_framework import (
Executor,
InMemoryCheckpointStorage,
WorkflowBuilder,
WorkflowCheckpoint,
WorkflowContext,
handler,
)
if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
@dataclass
class ComputeTask:
    """Unit of work passed between executors.

    Holds the queue of numbers that still need their factor pairs computed;
    the worker pops from the front of this list on each step.
    """

    remaining_numbers: list[int]
class StartExecutor(Executor):
    """Kicks off the workflow by turning an upper limit into the initial task."""

    @handler
    async def start(self, upper_limit: int, ctx: WorkflowContext[ComputeTask]) -> None:
        """Build the full list of numbers 1..upper_limit and send it onward."""
        print(f"StartExecutor: Starting factor pair computation up to {upper_limit}")
        numbers = [n for n in range(1, upper_limit + 1)]
        await ctx.send_message(ComputeTask(remaining_numbers=numbers))
class WorkerExecutor(Executor):
    """Processes numbers to compute their factor pairs and manages executor state for checkpointing.

    Accumulated results live in ``self._composite_number_pairs`` so that the
    ``on_checkpoint_save`` / ``on_checkpoint_restore`` hooks can persist and
    recover progress across workflow interruptions.
    """

    def __init__(self, id: str) -> None:
        super().__init__(id=id)
        # Maps each processed number n to its divisor pairs (i, n // i) for
        # every divisor i in 1..n-1; restored from a checkpoint on resume.
        self._composite_number_pairs: dict[int, list[tuple[int, int]]] = {}

    @handler
    async def compute(
        self,
        task: ComputeTask,
        ctx: WorkflowContext[ComputeTask, dict[int, list[tuple[int, int]]]],
    ) -> None:
        """Process the next number in the task, computing its factor pairs.

        Sends the shrunken task back to itself via the self-loop edge until no
        numbers remain, then yields the accumulated results as workflow output.
        """
        if not task.remaining_numbers:
            # Defensive guard: an empty task (e.g. an upper limit of 0) would
            # otherwise raise IndexError on pop(0). Yield whatever has been
            # collected so far and finish the workflow cleanly.
            await ctx.yield_output(self._composite_number_pairs)
            return

        next_number = task.remaining_numbers.pop(0)
        print(f"WorkerExecutor: Computing factor pairs for {next_number}")
        # Every divisor i of n with 1 <= i < n contributes the pair (i, n // i).
        pairs: list[tuple[int, int]] = [
            (i, next_number // i) for i in range(1, next_number) if next_number % i == 0
        ]
        self._composite_number_pairs[next_number] = pairs

        if not task.remaining_numbers:
            # All numbers processed - output the results
            await ctx.yield_output(self._composite_number_pairs)
        else:
            # More numbers to process - continue with remaining task
            await ctx.send_message(task)

    @override
    async def on_checkpoint_save(self) -> dict[str, Any]:
        """Save the executor's internal state for checkpointing."""
        return {"composite_number_pairs": self._composite_number_pairs}

    @override
    async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
        """Restore the executor's internal state from a checkpoint."""
        self._composite_number_pairs = state.get("composite_number_pairs", {})
async def main(upper_limit: int = 10) -> None:
    """Run the factor-pair workflow with checkpointing and simulated interruptions.

    Args:
        upper_limit: Highest number to compute factor pairs for. Defaults to
            10, matching the original sample's hard-coded value.

    Raises:
        RuntimeError: If an interruption occurs but no checkpoint exists to
            resume from.
    """
    # Build workflow with checkpointing enabled
    checkpoint_storage = InMemoryCheckpointStorage()
    start = StartExecutor(id="start")
    worker = WorkerExecutor(id="worker")
    workflow_builder = (
        WorkflowBuilder(start_executor=start, checkpoint_storage=checkpoint_storage)
        .add_edge(start, worker)
        .add_edge(worker, worker)  # Self-loop for iterative processing
    )

    # Run workflow with automatic checkpoint recovery: keep rebuilding the
    # workflow and resuming from the latest checkpoint until output appears.
    latest_checkpoint: WorkflowCheckpoint | None = None
    while True:
        workflow = workflow_builder.build()

        # Start from checkpoint or fresh execution
        print(f"\n** Workflow {workflow.id} started **")
        event_stream = (
            workflow.run(message=upper_limit, stream=True)
            if latest_checkpoint is None
            else workflow.run(checkpoint_id=latest_checkpoint.checkpoint_id, stream=True)
        )

        # The output payload is the worker's accumulated factor-pair mapping.
        output: dict[int, list[tuple[int, int]]] | None = None
        async for event in event_stream:
            if event.type == "output":
                output = event.data
                break
            if event.type == "superstep_completed" and random() < 0.5:
                # Randomly simulate system interruptions
                # The type="superstep_completed" event ensures we only interrupt after
                # the current super-step is fully complete and checkpointed.
                # If we interrupt mid-step, the workflow may resume from an earlier point.
                print("\n** Simulating workflow interruption. Stopping execution. **")
                break

        # Find the latest checkpoint to resume from
        latest_checkpoint = await checkpoint_storage.get_latest(workflow_name=workflow.name)
        if not latest_checkpoint:
            raise RuntimeError("No checkpoints available to resume from.")
        print(
            f"Checkpoint {latest_checkpoint.checkpoint_id}: "
            f"(iter={latest_checkpoint.iteration_count}, messages={latest_checkpoint.messages})"
        )

        if output is not None:
            print(f"\nWorkflow completed successfully with output: {output}")
            break


if __name__ == "__main__":
    asyncio.run(main())