Commit 6f2df1d

feat: add BatchResult serialization support with dedicated codec
1 parent a950699 commit 6f2df1d

File tree: 9 files changed (+770, −226 lines)


src/aws_durable_execution_sdk_python/concurrency.py

Lines changed: 28 additions & 197 deletions

@@ -7,7 +7,6 @@
 import threading
 import time
 from abc import ABC, abstractmethod
-from collections import Counter
 from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum
@@ -22,7 +21,11 @@
 from aws_durable_execution_sdk_python.identifier import OperationIdentifier
 from aws_durable_execution_sdk_python.lambda_service import ErrorObject
 from aws_durable_execution_sdk_python.operation.child import child_handler
-from aws_durable_execution_sdk_python.types import BatchResult as BatchResultProtocol
+from aws_durable_execution_sdk_python.types import (
+    BatchItem,
+    BatchItemStatus,
+    BatchResult,
+)

 if TYPE_CHECKING:
     from collections.abc import Callable
@@ -45,18 +48,6 @@


 # region Result models
-class BatchItemStatus(Enum):
-    SUCCEEDED = "SUCCEEDED"
-    FAILED = "FAILED"
-    STARTED = "STARTED"
-
-
-class CompletionReason(Enum):
-    ALL_COMPLETED = "ALL_COMPLETED"
-    MIN_SUCCESSFUL_REACHED = "MIN_SUCCESSFUL_REACHED"
-    FAILURE_TOLERANCE_EXCEEDED = "FAILURE_TOLERANCE_EXCEEDED"
-
-
 @dataclass(frozen=True)
 class SuspendResult:
     should_suspend: bool
@@ -71,173 +62,6 @@ def suspend(exception: SuspendExecution) -> SuspendResult:
         return SuspendResult(should_suspend=True, exception=exception)


-@dataclass(frozen=True)
-class BatchItem(Generic[R]):
-    index: int
-    status: BatchItemStatus
-    result: R | None = None
-    error: ErrorObject | None = None
-
-    def to_dict(self) -> dict:
-        return {
-            "index": self.index,
-            "status": self.status.value,
-            "result": self.result,
-            "error": self.error.to_dict() if self.error else None,
-        }
-
-    @classmethod
-    def from_dict(cls, data: dict) -> BatchItem[R]:
-        return cls(
-            index=data["index"],
-            status=BatchItemStatus(data["status"]),
-            result=data.get("result"),
-            error=ErrorObject.from_dict(data["error"]) if data.get("error") else None,
-        )
-
-
-@dataclass(frozen=True)
-class BatchResult(Generic[R], BatchResultProtocol[R]):  # noqa: PYI059
-    all: list[BatchItem[R]]
-    completion_reason: CompletionReason
-
-    @classmethod
-    def from_dict(
-        cls, data: dict, completion_config: CompletionConfig | None = None
-    ) -> BatchResult[R]:
-        batch_items: list[BatchItem[R]] = [
-            BatchItem.from_dict(item) for item in data["all"]
-        ]
-
-        completion_reason_value = data.get("completionReason")
-        if completion_reason_value is None:
-            # Infer completion reason from batch item statuses and completion config
-            # This aligns with the TypeScript implementation that uses completion config
-            # to accurately reconstruct the completion reason during replay
-            result = cls.from_items(batch_items, completion_config)
-            logger.warning(
-                "Missing completionReason in BatchResult deserialization, "
-                "inferred '%s' from batch item statuses. "
-                "This may indicate incomplete serialization data.",
-                result.completion_reason.value,
-            )
-            return result
-
-        completion_reason = CompletionReason(completion_reason_value)
-        return cls(batch_items, completion_reason)
-
-    @classmethod
-    def from_items(
-        cls,
-        items: list[BatchItem[R]],
-        completion_config: CompletionConfig | None = None,
-    ):
-        """
-        Infer completion reason based on batch item statuses and completion config.
-
-        This follows the same logic as the TypeScript implementation:
-        - If all items completed: ALL_COMPLETED
-        - If minSuccessful threshold met and not all completed: MIN_SUCCESSFUL_REACHED
-        - Otherwise: FAILURE_TOLERANCE_EXCEEDED
-        """
-
-        statuses = (item.status for item in items)
-        counts = Counter(statuses)
-        succeeded_count = counts.get(BatchItemStatus.SUCCEEDED, 0)
-        failed_count = counts.get(BatchItemStatus.FAILED, 0)
-        started_count = counts.get(BatchItemStatus.STARTED, 0)
-
-        completed_count = succeeded_count + failed_count
-        total_count = started_count + completed_count
-
-        # If all items completed (no started items), it's ALL_COMPLETED
-        if completed_count == total_count:
-            completion_reason = CompletionReason.ALL_COMPLETED
-        elif (  # If we have completion config and minSuccessful threshold is met
-            completion_config
-            and (min_successful := completion_config.min_successful) is not None
-            and succeeded_count >= min_successful
-        ):
-            completion_reason = CompletionReason.MIN_SUCCESSFUL_REACHED
-        else:
-            # Otherwise, assume failure tolerance was exceeded
-            completion_reason = CompletionReason.FAILURE_TOLERANCE_EXCEEDED
-
-        return cls(items, completion_reason)
-
-    def to_dict(self) -> dict:
-        return {
-            "all": [item.to_dict() for item in self.all],
-            "completionReason": self.completion_reason.value,
-        }
-
-    def succeeded(self) -> list[BatchItem[R]]:
-        return [
-            item
-            for item in self.all
-            if item.status is BatchItemStatus.SUCCEEDED and item.result is not None
-        ]
-
-    def failed(self) -> list[BatchItem[R]]:
-        return [
-            item
-            for item in self.all
-            if item.status is BatchItemStatus.FAILED and item.error is not None
-        ]
-
-    def started(self) -> list[BatchItem[R]]:
-        return [item for item in self.all if item.status is BatchItemStatus.STARTED]
-
-    @property
-    def status(self) -> BatchItemStatus:
-        return BatchItemStatus.FAILED if self.has_failure else BatchItemStatus.SUCCEEDED
-
-    @property
-    def has_failure(self) -> bool:
-        return any(item.status is BatchItemStatus.FAILED for item in self.all)
-
-    def throw_if_error(self) -> None:
-        first_error = next(
-            (item.error for item in self.all if item.status is BatchItemStatus.FAILED),
-            None,
-        )
-        if first_error:
-            raise first_error.to_callable_runtime_error()
-
-    def get_results(self) -> list[R]:
-        return [
-            item.result
-            for item in self.all
-            if item.status is BatchItemStatus.SUCCEEDED and item.result is not None
-        ]
-
-    def get_errors(self) -> list[ErrorObject]:
-        return [
-            item.error
-            for item in self.all
-            if item.status is BatchItemStatus.FAILED and item.error is not None
-        ]
-
-    @property
-    def success_count(self) -> int:
-        return sum(1 for item in self.all if item.status is BatchItemStatus.SUCCEEDED)
-
-    @property
-    def failure_count(self) -> int:
-        return sum(1 for item in self.all if item.status is BatchItemStatus.FAILED)
-
-    @property
-    def started_count(self) -> int:
-        return sum(1 for item in self.all if item.status is BatchItemStatus.STARTED)
-
-    @property
-    def total_count(self) -> int:
-        return len(self.all)
-
-
-# endregion Result models
-
-
 # region concurrency models
 @dataclass(frozen=True)
 class Executable(Generic[CallableType]):
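
BatchItem, BatchItemStatus, and BatchResult are deleted here and re-imported from aws_durable_execution_sdk_python.types (see the updated import in the second hunk), so the data model itself survives the move. A minimal round-trip sketch of that model, assuming the relocated classes keep the interface shown in the removed code (the item values are illustrative):

from aws_durable_execution_sdk_python.types import (
    BatchItem,
    BatchItemStatus,
    BatchResult,
)

# Two finished items; from_items infers the completion reason from their statuses.
items = [
    BatchItem(index=0, status=BatchItemStatus.SUCCEEDED, result="a"),
    BatchItem(index=1, status=BatchItemStatus.SUCCEEDED, result="b"),
]
batch = BatchResult.from_items(items)      # all items done, so ALL_COMPLETED

payload = batch.to_dict()
# {"all": [{"index": 0, "status": "SUCCEEDED", "result": "a", "error": None}, ...],
#  "completionReason": "ALL_COMPLETED"}

restored = BatchResult.from_dict(payload)  # exact reconstruction from the dict
assert restored.completion_reason is batch.completion_reason
assert restored.get_results() == ["a", "b"]

When "completionReason" is missing from a payload, from_dict logs a warning and falls back to from_items, inferring the reason from the item statuses and the optional CompletionConfig, as the removed implementation above shows.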
@@ -367,12 +191,12 @@ class ExecutionCounters:
     def __init__(
         self,
         total_tasks: int,
-        min_successful: int,
+        min_successful: int | None,
         tolerated_failure_count: int | None,
         tolerated_failure_percentage: float | None,
     ):
         self.total_tasks: int = total_tasks
-        self.min_successful: int = min_successful
+        self.min_successful: int | None = min_successful
         self.tolerated_failure_count: int | None = tolerated_failure_count
         self.tolerated_failure_percentage: float | None = tolerated_failure_percentage
         self.success_count: int = 0
@@ -421,24 +245,26 @@ def is_complete(self) -> bool:
         """
         Check if execution should complete (based on completion criteria).
         Matches TypeScript isComplete() logic.
+
+        Note: This method only checks completion criteria (all done, or min_successful met).
+        Failure tolerance is enforced separately by should_continue() and combined in should_complete().
         """
         with self._lock:
             completed_count = self.success_count + self.failure_count

             # All tasks completed
             if completed_count == self.total_tasks:
-                # Complete if no failure tolerance OR no failures OR min successful reached
-                return (
-                    (
-                        self.tolerated_failure_count is None
-                        and self.tolerated_failure_percentage is None
-                    )
-                    or self.failure_count == 0
-                    or self.success_count >= self.min_successful
-                )
+                # If min_successful is explicitly set, check if we met it
+                # Otherwise, complete when all tasks are done
+                if self.min_successful is not None:
+                    return self.success_count >= self.min_successful
+                return True

-            # when we breach min successful, we've completed
-            return self.success_count >= self.min_successful
+            # Early completion: when we breach min_successful (only if explicitly set)
+            return (
+                self.min_successful is not None
+                and self.success_count >= self.min_successful
+            )

     def should_complete(self) -> bool:
         """
@@ -453,9 +279,12 @@ def is_all_completed(self) -> bool:
             return self.success_count == self.total_tasks

     def is_min_successful_reached(self) -> bool:
-        """True if minimum successful tasks reached."""
+        """True if minimum successful task is both set and reached."""
         with self._lock:
-            return self.success_count >= self.min_successful
+            return (
+                self.min_successful is not None
+                and self.success_count >= self.min_successful
+            )

     def is_failure_tolerance_exceeded(self) -> bool:
         """True if failure tolerance was exceeded."""
@@ -594,7 +423,9 @@ def __init__(
         self._suspend_exception: SuspendExecution | None = None

         # ExecutionCounters will keep track of completion criteria and on-going counters
-        min_successful = self.completion_config.min_successful or len(self.executables)
+        # Note: min_successful should remain None if not explicitly set
+        # When None, the operation completes when all tasks finish (respecting failure tolerance)
+        min_successful = self.completion_config.min_successful
         tolerated_failure_count = self.completion_config.tolerated_failure_count
         tolerated_failure_percentage = (
             self.completion_config.tolerated_failure_percentage

src/aws_durable_execution_sdk_python/operation/map.py

Lines changed: 1 addition & 0 deletions

@@ -82,6 +82,7 @@ def from_items(
             name_prefix="map-item-",
             serdes=config.serdes,
             summary_generator=config.summary_generator,
+            item_serdes=config.item_serdes,
         )

     def execute_item(self, child_context, executable: Executable[Callable]) -> R:

src/aws_durable_execution_sdk_python/operation/parallel.py

Lines changed: 2 additions & 2 deletions

@@ -12,12 +12,11 @@
 from aws_durable_execution_sdk_python.lambda_service import OperationSubType

 if TYPE_CHECKING:
-    from aws_durable_execution_sdk_python.concurrency import BatchResult
     from aws_durable_execution_sdk_python.context import DurableContext
     from aws_durable_execution_sdk_python.identifier import OperationIdentifier
     from aws_durable_execution_sdk_python.serdes import SerDes
     from aws_durable_execution_sdk_python.state import ExecutionState
-    from aws_durable_execution_sdk_python.types import SummaryGenerator
+    from aws_durable_execution_sdk_python.types import BatchResult, SummaryGenerator

 logger = logging.getLogger(__name__)

@@ -69,6 +68,7 @@ def from_callables(
             name_prefix="parallel-branch-",
             serdes=config.serdes,
             summary_generator=config.summary_generator,
+            item_serdes=config.item_serdes,
         )

     def execute_item(self, child_context, executable: Executable[Callable]) -> R:  # noqa: PLR6301

src/aws_durable_execution_sdk_python/serdes.py

Lines changed: 14 additions & 2 deletions

@@ -37,6 +37,7 @@
     ExecutionError,
     SerDesError,
 )
+from aws_durable_execution_sdk_python.types import BatchResult

 logger = logging.getLogger(__name__)

@@ -62,6 +63,7 @@ class TypeTag(StrEnum):
     TUPLE = "t"
     LIST = "l"
     DICT = "m"
+    BATCH_RESULT = "br"


 @dataclass(frozen=True)
@@ -207,6 +209,12 @@ def dispatcher(self):
     def encode(self, obj: Any) -> EncodedValue:
         """Encode container using dispatcher for recursive elements."""
         match obj:
+            case BatchResult():
+                # Encode BatchResult as dict with special tag
+                return EncodedValue(
+                    TypeTag.BATCH_RESULT,
+                    self._wrap(obj.to_dict(), self.dispatcher).value,
+                )
             case list():
                 return EncodedValue(
                     TypeTag.LIST, [self._wrap(v, self.dispatcher) for v in obj]
@@ -231,6 +239,10 @@ def encode(self, obj: Any) -> EncodedValue:
     def decode(self, tag: TypeTag, value: Any) -> Any:
         """Decode container using dispatcher for recursive elements."""
         match tag:
+            case TypeTag.BATCH_RESULT:
+                # Decode as dict (handles all recursive unwrapping) then reconstruct
+                decoded_dict = self.decode(TypeTag.DICT, value)
+                return BatchResult.from_dict(decoded_dict)
             case TypeTag.LIST:
                 if not isinstance(value, list):
                     msg = f"Expected list, got {type(value)}"
@@ -292,7 +304,7 @@ def encode(self, obj: Any) -> EncodedValue:
                 return self.decimal_codec.encode(obj)
             case datetime() | date():
                 return self.datetime_codec.encode(obj)
-            case list() | tuple() | dict():
+            case BatchResult() | list() | tuple() | dict():
                 return self.container_codec.encode(obj)
             case _:
                 msg = f"Unsupported type: {type(obj)}"
@@ -316,7 +328,7 @@ def decode(self, tag: TypeTag, value: Any) -> Any:
                 return self.decimal_codec.decode(tag, value)
             case TypeTag.DATETIME | TypeTag.DATE:
                 return self.datetime_codec.decode(tag, value)
-            case TypeTag.LIST | TypeTag.TUPLE | TypeTag.DICT:
+            case TypeTag.BATCH_RESULT | TypeTag.LIST | TypeTag.TUPLE | TypeTag.DICT:
                 return self.container_codec.decode(tag, value)
             case _:
                 msg = f"Unknown type tag: {tag}"
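
Taken together, these serdes changes mean a BatchResult now encodes under its own "br" tag instead of falling through to the "Unsupported type" branch. A rough, self-contained sketch of the dispatch this diff adds, assuming BatchResult keeps the from_items/to_dict/from_dict interface shown in the removed concurrency.py code; the SDK's real container codec also recursively wraps and unwraps nested values through the dispatcher, which is elided here:

from aws_durable_execution_sdk_python.types import (
    BatchItem,
    BatchItemStatus,
    BatchResult,
)

BATCH_RESULT_TAG = "br"  # mirrors the new TypeTag.BATCH_RESULT

def encode_sketch(obj):
    # Simplified stand-in for the container codec's encode: tag the to_dict() payload.
    if isinstance(obj, BatchResult):
        return (BATCH_RESULT_TAG, obj.to_dict())
    raise TypeError(f"Unsupported type: {type(obj)}")

def decode_sketch(tag, value):
    # Simplified stand-in for decode: the real code first decodes the payload as a
    # DICT, then reconstructs the object via BatchResult.from_dict.
    if tag == BATCH_RESULT_TAG:
        return BatchResult.from_dict(value)
    raise ValueError(f"Unknown type tag: {tag}")

batch = BatchResult.from_items(
    [BatchItem(index=0, status=BatchItemStatus.SUCCEEDED, result=42)]
)
tag, payload = encode_sketch(batch)
assert decode_sketch(tag, payload).get_results() == [42]

Because the decode path for the "br" tag reuses the DICT handling before calling from_dict, nested values inside each BatchItem keep flowing through the same recursive codec machinery as any other dict.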
